diff --git a/.golangci.yml b/.golangci.yml index 5903fa7b..32298872 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -63,6 +63,8 @@ linters-settings: exclude-generated: true severity: "low" confidence: "low" + nolintlint: + allow-unused: true run: timeout: 5m diff --git a/README.md b/README.md index b91777a1..7351b3ec 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,34 @@ To disable Cloud Fetch (e.g., when handling smaller datasets or to avoid additio token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?useCloudFetch=false ``` +### Telemetry Configuration (Optional) + +The driver includes optional telemetry that helps Databricks improve driver performance and reliability. Telemetry is **disabled by default** and requires explicit opt-in. + +**Opt in to telemetry** (respects server-side feature flags): +``` +token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?enableTelemetry=true +``` + +**Opt out of telemetry** (explicitly disable): +``` +token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?enableTelemetry=false +``` + +**What data is collected:** +- ✅ Query latency and performance metrics +- ✅ Error codes (not error messages) +- ✅ Feature usage (CloudFetch, LZ4, etc.) +- ✅ Driver version and environment info + +**What is NOT collected:** +- ❌ SQL query text +- ❌ Query results or data values +- ❌ Table/column names +- ❌ User identities or credentials + +Telemetry adds less than 1% performance overhead and uses a circuit breaker that stops sending events if the telemetry endpoint becomes unhealthy, so it never blocks your queries. For more details, see `telemetry/DESIGN.md` and `telemetry/TROUBLESHOOTING.md`. + ### Connecting with a new Connector You can also connect with a new connector object. For example: diff --git a/auth/oauth/u2m/authenticator.go b/auth/oauth/u2m/authenticator.go index ba9d4d72..456fba6d 100644 --- a/auth/oauth/u2m/authenticator.go +++ b/auth/oauth/u2m/authenticator.go @@ -40,16 +40,17 @@ func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticato cloud := oauth.InferCloudFromHost(hostName) var clientID, redirectURL string - if cloud == oauth.AWS { + switch cloud { + case oauth.AWS: clientID = awsClientId redirectURL = awsRedirectURL - } else if cloud == oauth.Azure { + case oauth.Azure: clientID = azureClientId redirectURL = azureRedirectURL - } else if cloud == oauth.GCP { + case oauth.GCP: clientID = gcpClientId redirectURL = gcpRedirectURL - } else { + default: return nil, errors.New("unhandled cloud type: " + cloud.String()) } @@ -147,14 +148,14 @@ func (tsp *tokenSourceProvider) GetTokenSource() (oauth2.TokenSource, error) { if err != nil { return nil, err } - defer listener.Close() + defer listener.Close() //nolint:errcheck srv := &http.Server{ ReadHeaderTimeout: 3 * time.Second, WriteTimeout: 30 * time.Second, } - defer srv.Close() + defer srv.Close() //nolint:errcheck // Start local server to wait for callback go func() { @@ -209,7 +210,7 @@ func (tsp *tokenSourceProvider) ServeHTTP(w http.ResponseWriter, r *http.Request if resp.err != "" { log.Error().Msg(resp.err) w.WriteHeader(http.StatusBadRequest) - _, err := w.Write([]byte(errorHTML("Identity Provider returned an error: " + resp.err))) + _, err := w.Write([]byte(errorHTML("Identity Provider returned an error: " + resp.err))) //nolint:gosec // XSS not a concern for local OAuth callback if err != nil { log.Error().Err(err).Msg("unable to write error response") } diff --git a/auth/tokenprovider/exchange.go b/auth/tokenprovider/exchange.go index bf494634..09c9268d 100644 ---
a/auth/tokenprovider/exchange.go +++ b/auth/tokenprovider/exchange.go @@ -138,7 +138,7 @@ func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken } // Create request - req, err := http.NewRequestWithContext(ctx, "POST", exchangeURL, strings.NewReader(data.Encode())) + req, err := http.NewRequestWithContext(ctx, "POST", exchangeURL, strings.NewReader(data.Encode())) //nolint:gosec // URL is from trusted config if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -147,11 +147,11 @@ func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken req.Header.Set("Accept", "*/*") // Make request - resp, err := p.httpClient.Do(req) + resp, err := p.httpClient.Do(req) //nolint:gosec // G704: URL is from trusted configuration if err != nil { return nil, fmt.Errorf("request failed: %w", err) } - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/auth/tokenprovider/federation_test.go b/auth/tokenprovider/federation_test.go index de49da04..9c0ccfb1 100644 --- a/auth/tokenprovider/federation_test.go +++ b/auth/tokenprovider/federation_test.go @@ -108,7 +108,8 @@ func TestFederationProvider_TokenExchangeSuccess(t *testing.T) { assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) assert.Equal(t, "*/*", r.Header.Get("Accept")) - // Parse form data + // Parse form data - limit body size to prevent G120 + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) err := r.ParseForm() require.NoError(t, err) @@ -155,13 +156,14 @@ func TestFederationProvider_TokenExchangeWithClientID(t *testing.T) { // Create mock server that checks for client_id server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) err := r.ParseForm() require.NoError(t, err) // Verify client_id is present assert.Equal(t, clientID, r.FormValue("client_id")) - response := map[string]interface{}{ + response := map[string]interface{}{ //nolint:gosec // G101: test token, not a real credential "access_token": "sp-wide-federation-token", "token_type": "Bearer", "expires_in": 3600, diff --git a/auth/tokenprovider/provider_test.go b/auth/tokenprovider/provider_test.go index e3df4753..0ee8c62e 100644 --- a/auth/tokenprovider/provider_test.go +++ b/auth/tokenprovider/provider_test.go @@ -146,7 +146,7 @@ func TestExternalTokenProvider(t *testing.T) { callCount := 0 tokenFunc := func() (string, error) { callCount++ - return "external-token-" + string(rune(callCount)), nil + return "external-token-" + string(rune(callCount)), nil //nolint:gosec // G115: test counter, values are always small } provider := NewExternalTokenProvider(tokenFunc) @@ -211,7 +211,7 @@ func TestExternalTokenProvider(t *testing.T) { counter := 0 tokenFunc := func() (string, error) { counter++ - return "token-" + string(rune(counter)), nil + return "token-" + string(rune(counter)), nil //nolint:gosec // G115: test counter, values are always small } provider := NewExternalTokenProvider(tokenFunc) diff --git a/connection.go b/connection.go index 01e5e8cb..484fae11 100644 --- a/connection.go +++ b/connection.go @@ -52,16 +52,19 @@ func (c *conn) Close() error { log := logger.WithContext(c.id, "", "") ctx := driverctx.NewContextWithConnId(context.Background(), c.id) - // Close telemetry and release resources + // Time CloseSession so we can record DELETE_SESSION before flushing telemetry + closeStart := time.Now() + _, err := 
c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{ + SessionHandle: c.session.SessionHandle, + }) + + // Record DELETE_SESSION regardless of error (matches JDBC), then flush and release if c.telemetry != nil { + c.telemetry.RecordOperation(ctx, c.id, telemetry.OperationTypeDeleteSession, time.Since(closeStart).Milliseconds(), err) _ = c.telemetry.Close(ctx) telemetry.ReleaseForConnection(c.cfg.Host) } - _, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{ - SessionHandle: c.session.SessionHandle, - }) - if err != nil { log.Err(err).Msg("databricks: failed to close connection") return dbsqlerrint.NewBadConnectionError(err) @@ -93,7 +96,7 @@ func (c *conn) Ping(ctx context.Context) error { log.Err(err).Msg("databricks: failed to ping") return dbsqlerrint.NewBadConnectionError(err) } - defer rows.Close() + defer rows.Close() //nolint:errcheck log.Debug().Msg("databricks: ping successful") return nil @@ -155,9 +158,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name alreadyClosed := exStmtResp.DirectResults != nil && exStmtResp.DirectResults.CloseOperation != nil newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) if !alreadyClosed && (opStatusResp == nil || opStatusResp.GetOperationState() != cli_service.TOperationState_CLOSED_STATE) { + closeOpStart := time.Now() _, err1 := c.client.CloseOperation(newCtx, &cli_service.TCloseOperationReq{ OperationHandle: exStmtResp.OperationHandle, }) + if c.telemetry != nil { + c.telemetry.RecordOperation(ctx, c.id, telemetry.OperationTypeCloseStatement, time.Since(closeOpStart).Milliseconds(), err1) + } if err1 != nil { log.Err(err1).Msg("databricks: failed to close operation after executing statement") closeOpErr = err1 // Capture for telemetry @@ -216,7 +223,15 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp) } - rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + // Telemetry callback for tracking row fetching metrics + telemetryUpdate := func(chunkCount int, bytesDownloaded int64) { + if c.telemetry != nil { + c.telemetry.AddTag(ctx, "chunk_count", chunkCount) + c.telemetry.AddTag(ctx, "bytes_downloaded", bytesDownloaded) + } + } + + rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, telemetryUpdate) return rows, err } @@ -381,7 +396,14 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver } } + executeStart := time.Now() resp, err := c.client.ExecuteStatement(ctx, &req) + // Record the Thrift call latency as a separate operation metric. + // This is distinct from the statement-level metric (BeforeExecuteWithTime), which + // measures end-to-end latency including polling and row fetching. 
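+ // The err value is passed through as well, so failed ExecuteStatement calls are recorded in this metric.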
+ if c.telemetry != nil { + c.telemetry.RecordOperation(ctx, c.id, telemetry.OperationTypeExecuteStatement, time.Since(executeStart).Milliseconds(), err) + } var log *logger.DBSQLLogger log, ctx = client.LoggerAndContext(ctx, resp) @@ -514,11 +536,11 @@ func (c *conn) handleStagingPut(ctx context.Context, presignedUrl string, header } client := &http.Client{} - dat, err := os.Open(localFile) + dat, err := os.Open(localFile) //nolint:gosec // localFile is provided by the application, not user input if err != nil { return dbsqlerrint.NewDriverError(ctx, "error reading local file", err) } - defer dat.Close() + defer dat.Close() //nolint:errcheck info, err := dat.Stat() if err != nil { @@ -535,7 +557,7 @@ func (c *conn) handleStagingPut(ctx context.Context, presignedUrl string, header if err != nil { return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) } - defer res.Body.Close() + defer res.Body.Close() //nolint:errcheck content, err := io.ReadAll(res.Body) if err != nil || !Succeeded(res) { @@ -559,7 +581,7 @@ func (c *conn) handleStagingGet(ctx context.Context, presignedUrl string, header if err != nil { return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) } - defer res.Body.Close() + defer res.Body.Close() //nolint:errcheck content, err := io.ReadAll(res.Body) if err != nil || !Succeeded(res) { @@ -583,7 +605,7 @@ func (c *conn) handleStagingRemove(ctx context.Context, presignedUrl string, hea if err != nil { return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) } - defer res.Body.Close() + defer res.Body.Close() //nolint:errcheck content, err := io.ReadAll(res.Body) if err != nil || !Succeeded(res) { @@ -646,11 +668,18 @@ func (c *conn) execStagingOperation( } if len(driverctx.StagingPathsFromContext(ctx)) != 0 { - row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + // Telemetry callback for staging operation row fetching + telemetryUpdate := func(chunkCount int, bytesDownloaded int64) { + if c.telemetry != nil { + c.telemetry.AddTag(ctx, "chunk_count", chunkCount) + c.telemetry.AddTag(ctx, "bytes_downloaded", bytesDownloaded) + } + } + row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, telemetryUpdate) if err != nil { return dbsqlerrint.NewDriverError(ctx, "error reading row.", err) } - defer row.Close() + defer row.Close() //nolint:errcheck } else { return dbsqlerrint.NewDriverError(ctx, "staging ctx must be provided.", nil) @@ -663,7 +692,7 @@ func (c *conn) execStagingOperation( if err != nil { return dbsqlerrint.NewDriverError(ctx, "error fetching staging operation results", err) } - var stringValues []string = make([]string, 4) + stringValues := make([]string, 4) for i, val := range sqlRow { // this will either be 3 (remove op) or 4 (put/get) elements if s, ok := val.(string); ok { stringValues[i] = s diff --git a/connection_test.go b/connection_test.go index a9a0e8b1..c4cb9f15 100644 --- a/connection_test.go +++ b/connection_test.go @@ -165,8 +165,8 @@ func TestConn_executeStatement(t *testing.T) { for _, opTest := range operationStateTests { closeOperationCount = 0 executeStatementCount = 0 - executeStatementResp.DirectResults.OperationStatus.OperationState = &opTest.state - executeStatementResp.DirectResults.OperationStatus.DisplayMessage = &opTest.err + executeStatementResp.DirectResults.OperationStatus.OperationState = &opTest.state //nolint:gosec // G601: pointer is used only within this loop iteration + 
executeStatementResp.DirectResults.OperationStatus.DisplayMessage = &opTest.err //nolint:gosec // G601: pointer is used only within this loop iteration _, err := testConn.ExecContext(context.Background(), "select 1", []driver.NamedValue{}) if opTest.err == "" { assert.NoError(t, err) diff --git a/connector.go b/connector.go index f5d33d37..3e3ad330 100644 --- a/connector.go +++ b/connector.go @@ -55,6 +55,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } protocolVersion := int64(c.cfg.ThriftProtocolVersion) + + sessionStart := time.Now() session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{ ClientProtocolI64: &protocolVersion, Configuration: sessionParams, @@ -64,6 +66,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { }, CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs, }) + sessionLatencyMs := time.Since(sessionStart).Milliseconds() + if err != nil { return nil, dbsqlerrint.NewRequestError(ctx, fmt.Sprintf("error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath), err) } @@ -80,11 +84,15 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { conn.telemetry = telemetry.InitializeForConnection( ctx, c.cfg.Host, + c.cfg.DriverVersion, c.client, c.cfg.EnableTelemetry, + c.cfg.TelemetryBatchSize, + c.cfg.TelemetryFlushInterval, ) if conn.telemetry != nil { log.Debug().Msg("telemetry initialized for connection") + conn.telemetry.RecordOperation(ctx, conn.id, telemetry.OperationTypeCreateSession, sessionLatencyMs, nil) } log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion) @@ -290,8 +298,8 @@ func WithTransport(t http.RoundTripper) ConnOption { return func(c *config.Config) { c.Transport = t - if c.CloudFetchConfig.HTTPClient == nil { - c.CloudFetchConfig.HTTPClient = &http.Client{ + if c.HTTPClient == nil { + c.HTTPClient = &http.Client{ Transport: t, } } diff --git a/connector_test.go b/connector_test.go index c89b74e0..66f351c9 100644 --- a/connector_test.go +++ b/connector_test.go @@ -263,8 +263,8 @@ func TestNewConnector(t *testing.T) { coni, ok := con.(*connector) require.True(t, ok) - assert.NotNil(t, coni.cfg.CloudFetchConfig.HTTPClient) - assert.Equal(t, customTransport, coni.cfg.CloudFetchConfig.HTTPClient.Transport) + assert.NotNil(t, coni.cfg.HTTPClient) + assert.Equal(t, customTransport, coni.cfg.HTTPClient.Transport) }) } diff --git a/driver_e2e_test.go b/driver_e2e_test.go index bcb9ad94..fdd3a538 100644 --- a/driver_e2e_test.go +++ b/driver_e2e_test.go @@ -58,7 +58,7 @@ func TestWorkflowExample(t *testing.T) { ) require.NoError(t, err) db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck ogCtx := driverctx.NewContextWithCorrelationId(context.Background(), "workflow-example") @@ -271,7 +271,7 @@ func TestContextTimeoutExample(t *testing.T) { db, err := sql.Open("databricks", ts.URL+"/path") require.NoError(t, err) - defer db.Close() + defer db.Close() //nolint:errcheck ogCtx := driverctx.NewContextWithCorrelationId(context.Background(), "context-timeout-example") @@ -321,7 +321,7 @@ func TestRetries(t *testing.T) { db, err := sql.Open("databricks", fmt.Sprintf("%s/503-2-retries", ts.URL)) require.NoError(t, err) - defer db.Close() + defer db.Close() //nolint:errcheck state.executeStatementResp = cli_service.TExecuteStatementResp{} loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp) @@ -347,7 +347,7 @@ 
func TestRetries(t *testing.T) { db, err := sql.Open("databricks", fmt.Sprintf("%s/429-2-retries", ts.URL)) require.NoError(t, err) - defer db.Close() + defer db.Close() //nolint:errcheck state.executeStatementResp = cli_service.TExecuteStatementResp{} loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp) @@ -383,7 +383,7 @@ func TestRetries(t *testing.T) { ) require.NoError(t, err) db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck state.executeStatementResp = cli_service.TExecuteStatementResp{} loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp) @@ -418,7 +418,7 @@ func TestRetries(t *testing.T) { ) require.NoError(t, err) db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck state.executeStatementResp = cli_service.TExecuteStatementResp{} loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp) @@ -453,7 +453,7 @@ func TestRetries(t *testing.T) { ) require.NoError(t, err) db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck state.executeStatementResp = cli_service.TExecuteStatementResp{} loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp) @@ -479,7 +479,7 @@ func TestRetries(t *testing.T) { ) require.NoError(t, err) db2 := sql.OpenDB(connector2) - defer db.Close() + defer db2.Close() //nolint:errcheck state.executeStatementResp = cli_service.TExecuteStatementResp{} loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp) diff --git a/examples/arrrowbatches/main.go b/examples/arrrowbatches/main.go index c01bdbc4..ef14c733 100644 --- a/examples/arrrowbatches/main.go +++ b/examples/arrrowbatches/main.go @@ -44,7 +44,7 @@ func main() { } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck loopWithHasNext(db) loopWithNext(db) @@ -55,7 +55,7 @@ func loopWithHasNext(db *sql.DB) { defer cancel() conn, _ := db.Conn(ctx) - defer conn.Close() + defer conn.Close() //nolint:errcheck query := `select * from main.default.diamonds` @@ -69,7 +69,7 @@ func loopWithHasNext(db *sql.DB) { if err != nil { log.Fatalf("unable to run the query. err: %v", err) } - defer rows.Close() + defer rows.Close() //nolint:errcheck ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second) defer cancel2() @@ -99,7 +99,7 @@ func loopWithNext(db *sql.DB) { defer cancel() conn, _ := db.Conn(ctx) - defer conn.Close() + defer conn.Close() //nolint:errcheck query := `select * from main.default.diamonds` @@ -113,7 +113,7 @@ func loopWithNext(db *sql.DB) { if err != nil { log.Fatalf("unable to run the query. 
err: %v", err) } - defer rows.Close() + defer rows.Close() //nolint:errcheck ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second) defer cancel2() diff --git a/examples/browser_oauth_federation/main.go b/examples/browser_oauth_federation/main.go index b19302ad..449fe76f 100644 --- a/examples/browser_oauth_federation/main.go +++ b/examples/browser_oauth_federation/main.go @@ -56,7 +56,7 @@ func main() { } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck // Test connection - this triggers browser OAuth flow fmt.Println("Connecting (browser will open for login)...") diff --git a/examples/catalog/main.go b/examples/catalog/main.go index dc5fbd45..dd4b4178 100644 --- a/examples/catalog/main.go +++ b/examples/catalog/main.go @@ -37,7 +37,7 @@ func main() { log.Fatal(err) } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -61,7 +61,7 @@ func main() { err := rows.Scan(&res) if err != nil { fmt.Println(err) - rows.Close() + rows.Close() //nolint:errcheck,gosec // G104: close in error path return } fmt.Println(res) diff --git a/examples/cloudfetch/main.go b/examples/cloudfetch/main.go index 04f67945..519f9e1f 100644 --- a/examples/cloudfetch/main.go +++ b/examples/cloudfetch/main.go @@ -46,7 +46,7 @@ func runTest(withCloudFetch bool, query string) ([]row, error) { return nil, err } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -54,8 +54,6 @@ func runTest(withCloudFetch bool, query string) ([]row, error) { return nil, err } rows, err1 := db.QueryContext(context.Background(), query) - defer rows.Close() - if err1 != nil { if err1 == sql.ErrNoRows { fmt.Println("not found") @@ -64,6 +62,7 @@ func runTest(withCloudFetch bool, query string) ([]row, error) { return nil, err } } + defer rows.Close() //nolint:errcheck var res []row for rows.Next() { r := row{} @@ -94,7 +93,7 @@ func main() { for i := 0; i < len(abRes); i++ { if abRes[i] != cfRes[i] { - log.Fatal(fmt.Sprintf("not equal for row: %d", i)) + log.Fatalf("not equal for row: %d", i) } } } diff --git a/examples/createdrop/main.go b/examples/createdrop/main.go index c6835eed..5b0709af 100644 --- a/examples/createdrop/main.go +++ b/examples/createdrop/main.go @@ -56,7 +56,7 @@ func main() { // Opening a driver typically will not attempt to connect to the database. 
db := sql.OpenDB(connector) // make sure to close it later - defer db.Close() + defer db.Close() //nolint:errcheck ogCtx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "createdrop-example") diff --git a/examples/error/main.go b/examples/error/main.go index 9d34bf5c..8928a773 100644 --- a/examples/error/main.go +++ b/examples/error/main.go @@ -38,7 +38,7 @@ func main() { handleErr(err) db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck // test the connection err = db.Ping() @@ -95,7 +95,7 @@ func main() { } // At this point the query completed successfully - defer rows.Close() + defer rows.Close() //nolint:errcheck fmt.Printf("conn Id: %s, query Id: %s\n", connId, queryId) diff --git a/examples/ipcstreams/main.go b/examples/ipcstreams/main.go index c7b01da0..2cfc7d2c 100644 --- a/examples/ipcstreams/main.go +++ b/examples/ipcstreams/main.go @@ -37,13 +37,13 @@ func main() { } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() conn, _ := db.Conn(ctx) - defer conn.Close() + defer conn.Close() //nolint:errcheck query := `SELECT * FROM samples.nyctaxi.trips LIMIT 1000` @@ -57,7 +57,7 @@ func main() { if err != nil { log.Fatal("Failed to execute query: ", err) } - defer rows.Close() + defer rows.Close() //nolint:errcheck // Get the IPC stream iterator ipcStreams, err := rows.(dbsqlrows.Rows).GetArrowIPCStreams(ctx) diff --git a/examples/oauth/main.go b/examples/oauth/main.go index 7290a5a5..808472df 100644 --- a/examples/oauth/main.go +++ b/examples/oauth/main.go @@ -29,7 +29,7 @@ func testU2M() { authenticator, err := u2m.NewAuthenticator(os.Getenv("DATABRICKS_HOST"), 1*time.Minute) if err != nil { - log.Fatal(err.Error()) + log.Fatal(err.Error()) //nolint:gosec // G706: error tainted by env var, log injection not a concern here } connector, err := dbsql.NewConnector( @@ -42,7 +42,7 @@ func testU2M() { } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck // Pinging should require logging in if err := db.Ping(); err != nil { @@ -91,7 +91,7 @@ func testM2M() { } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck // Pinging should require logging in if err := db.Ping(); err != nil { diff --git a/examples/parameters/main.go b/examples/parameters/main.go index 869a4ede..f6ace61b 100644 --- a/examples/parameters/main.go +++ b/examples/parameters/main.go @@ -102,7 +102,7 @@ func main() { log.Fatal(err) } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck queryWithNamedParameters(db) queryWithPositionalParameters(db) diff --git a/examples/query_tags/main.go b/examples/query_tags/main.go index 750a5ede..8ad7d3da 100644 --- a/examples/query_tags/main.go +++ b/examples/query_tags/main.go @@ -40,7 +40,7 @@ func main() { } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck // Example 1: Connection-level query tags (set during connection) fmt.Println("=== Connection-level query tags ===") diff --git a/examples/queryrow/main.go b/examples/queryrow/main.go index dfd60d3d..0de5da07 100644 --- a/examples/queryrow/main.go +++ b/examples/queryrow/main.go @@ -35,7 +35,7 @@ func main() { log.Fatal(err) } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck // ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // defer cancel() ctx := context.Background() diff --git 
a/examples/queryrows/main.go b/examples/queryrows/main.go index e0f28daf..d7742e3b 100644 --- a/examples/queryrows/main.go +++ b/examples/queryrows/main.go @@ -35,7 +35,7 @@ func main() { log.Fatal(err) } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck // ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // defer cancel() @@ -80,7 +80,7 @@ func main() { err := rows.Scan(&res1, &res2) if err != nil { fmt.Println(err) - rows.Close() + rows.Close() //nolint:errcheck,gosec // G104: close in error path return } fmt.Printf("%v, %v\n", res1, res2) diff --git a/examples/staging/main.go b/examples/staging/main.go index 87ac148c..eb196b61 100644 --- a/examples/staging/main.go +++ b/examples/staging/main.go @@ -37,7 +37,7 @@ func main() { log.Fatal(err) } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck // ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // defer cancel() diff --git a/examples/timeout/main.go b/examples/timeout/main.go index 68ee7339..670fd7d3 100644 --- a/examples/timeout/main.go +++ b/examples/timeout/main.go @@ -31,7 +31,7 @@ func main() { // another initialization error. log.Fatal(err) } - defer db.Close() + defer db.Close() //nolint:errcheck ogCtx := driverctx.NewContextWithCorrelationId(context.Background(), "context-timeout-example") ctx1, cancel1 := context.WithTimeout(ogCtx, 10*time.Second) diff --git a/examples/timezone/main.go b/examples/timezone/main.go index 97665b05..5fadc119 100644 --- a/examples/timezone/main.go +++ b/examples/timezone/main.go @@ -43,7 +43,7 @@ func main() { log.Fatal(err) } db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck db.SetMaxOpenConns(1) // ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -84,7 +84,7 @@ func main() { log.Fatal(err1) } db1 := sql.OpenDB(connector1) - defer db1.Close() + defer db1.Close() //nolint:errcheck res1 := &time.Time{} res2 := &time.Time{} diff --git a/examples/token_federation/main.go b/examples/token_federation/main.go index 85f675e0..32fda024 100644 --- a/examples/token_federation/main.go +++ b/examples/token_federation/main.go @@ -44,7 +44,7 @@ func main() { case "custom": runCustomProviderExample() default: - log.Fatalf("Unknown example: %s (use: static, external, custom)", example) + log.Fatalf("Unknown example: %s (use: static, external, custom)", example) //nolint:gosec // G706: CLI arg, log injection not a concern } } @@ -143,7 +143,7 @@ func runCustomProviderExample() { // testConnection verifies the connection works func testConnection(connector driver.Connector) { db := sql.OpenDB(connector) - defer db.Close() + defer db.Close() //nolint:errcheck ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() diff --git a/examples/workflow/main.go b/examples/workflow/main.go index bd892946..3022ebaf 100644 --- a/examples/workflow/main.go +++ b/examples/workflow/main.go @@ -57,7 +57,7 @@ func main() { // Opening a driver typically will not attempt to connect to the database. 
db := sql.OpenDB(connector) // make sure to close it later - defer db.Close() + defer db.Close() //nolint:errcheck // the log looks like this: // ``` diff --git a/internal/cli_service/thrift_field_id_test.go b/internal/cli_service/thrift_field_id_test.go index 29a7fd0c..f70546aa 100644 --- a/internal/cli_service/thrift_field_id_test.go +++ b/internal/cli_service/thrift_field_id_test.go @@ -53,11 +53,11 @@ func TestThriftFieldIdsAreWithinAllowedRange(t *testing.T) { // validateThriftFieldIDs parses the cli_service.go file and extracts all thrift field IDs // to validate they are within the allowed range. func validateThriftFieldIDs(filePath string, maxAllowedFieldID int) ([]string, error) { - file, err := os.Open(filePath) + file, err := os.Open(filePath) //nolint:gosec // G304: path is a test fixture, not user-controlled if err != nil { return nil, fmt.Errorf("failed to open file %s: %w", filePath, err) } - defer file.Close() + defer file.Close() //nolint:errcheck var violations []string scanner := bufio.NewScanner(file) diff --git a/internal/client/client.go b/internal/client/client.go index b644c294..6aca7428 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -519,7 +519,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { if req.Body != nil { defer func() { if !reqBodyClosed { - req.Body.Close() + req.Body.Close() //nolint:errcheck,gosec // G104: close in deferred cleanup } }() } diff --git a/internal/config/config.go b/internal/config/config.go index 16260d3c..e5446ac8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -102,6 +102,8 @@ type UserConfig struct { // Uses config overlay pattern: client > server > default. // Unset = check server feature flag; explicitly true/false overrides the server. EnableTelemetry ConfigValue[bool] + TelemetryBatchSize int // 0 = use default (100) + TelemetryFlushInterval time.Duration // 0 = use default (5s) Transport http.RoundTripper UseLz4Compression bool EnableMetricViewMetadata bool @@ -149,6 +151,8 @@ func (ucfg UserConfig) DeepCopy() UserConfig { EnableMetricViewMetadata: ucfg.EnableMetricViewMetadata, CloudFetchConfig: ucfg.CloudFetchConfig, EnableTelemetry: ucfg.EnableTelemetry, + TelemetryBatchSize: ucfg.TelemetryBatchSize, + TelemetryFlushInterval: ucfg.TelemetryFlushInterval, } } @@ -184,6 +188,9 @@ func (ucfg UserConfig) WithDefaults() UserConfig { ucfg.UseLz4Compression = false ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults() + // EnableTelemetry defaults to unset (ConfigValue zero value), + // meaning telemetry is controlled by server feature flags. + return ucfg } @@ -294,6 +301,19 @@ func ParseDSN(dsn string) (UserConfig, error) { } ucfg.EnableTelemetry = NewConfigValue(enableTelemetry) } + if batchSize, ok, err := params.extractAsInt("telemetry_batch_size"); ok { + if err != nil { + return UserConfig{}, err + } + if batchSize > 0 { + ucfg.TelemetryBatchSize = batchSize + } + } + if flushInterval, ok := params.extract("telemetry_flush_interval"); ok { + if d, err := time.ParseDuration(flushInterval); err == nil && d > 0 { + ucfg.TelemetryFlushInterval = d + } + } // for timezone we do a case insensitive key match. 
// We use getNoCase because we want to leave timezone in the params so that it will also diff --git a/internal/config/config_test.go b/internal/config/config_test.go index c2821c78..abea52b0 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -51,7 +51,7 @@ func TestParseConfig(t *testing.T) { }, { name: "with https scheme", - args: args{dsn: "https://token:supersecret@example.cloud.databricks.com:443/sql/1.0/endpoints/12346a5b5b0e123a"}, + args: args{dsn: "https://token:supersecret@example.cloud.databricks.com:443/sql/1.0/endpoints/12346a5b5b0e123a"}, //nolint:gosec // G101: test DSN with example password, not a real credential wantCfg: UserConfig{ Protocol: "https", Host: "example.cloud.databricks.com", diff --git a/internal/errors/err.go b/internal/errors/err.go index 26a642e3..7d6765d6 100644 --- a/internal/errors/err.go +++ b/internal/errors/err.go @@ -50,7 +50,7 @@ func newDatabricksError(ctx context.Context, msg string, err error) databricksEr // If the error chain contains an instance of retryableError // set the flag and retryAfter value. - var retryable bool = false + retryable := false var retryAfter time.Duration if errors.Is(err, RetryableError) { retryable = true diff --git a/internal/rows/arrowbased/arrowRecordIterator.go b/internal/rows/arrowbased/arrowRecordIterator.go index 787a0bab..d0b620db 100644 --- a/internal/rows/arrowbased/arrowRecordIterator.go +++ b/internal/rows/arrowbased/arrowRecordIterator.go @@ -82,7 +82,7 @@ func (ri *arrowRecordIterator) Close() { } if ri.resultPageIterator != nil { - ri.resultPageIterator.Close() + ri.resultPageIterator.Close() //nolint:errcheck,gosec // G104: close in cleanup } } } diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go index 4e2cf802..c4a98e69 100644 --- a/internal/rows/arrowbased/arrowRows.go +++ b/internal/rows/arrowbased/arrowRows.go @@ -127,7 +127,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp return nil, err2 } - var location *time.Location = time.UTC + location := time.UTC if cfg != nil { if cfg.Location != nil { location = cfg.Location diff --git a/internal/rows/arrowbased/arrowRows_test.go b/internal/rows/arrowbased/arrowRows_test.go index 9d7eaba8..84b27426 100644 --- a/internal/rows/arrowbased/arrowRows_test.go +++ b/internal/rows/arrowbased/arrowRows_test.go @@ -239,7 +239,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + ars := d.(*arrowRowScanner) err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(0, 1)) require.Nil(t, err) @@ -315,7 +315,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + ars := d.(*arrowRowScanner) err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(0, 1)) require.Nil(t, err) @@ -417,7 +417,7 @@ func TestArrowRowScanner(t *testing.T) { d, err1 := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) require.Nil(t, err1) - var ars *arrowRowScanner = d.(*arrowRowScanner) + ars := d.(*arrowRowScanner) err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(0, 0)) require.Nil(t, err) @@ -485,7 +485,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + 
ars := d.(*arrowRowScanner) assert.Nil(t, ars.rowValues) @@ -555,7 +555,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) - var ars *arrowRowScanner = d.(*arrowRowScanner) + ars := d.(*arrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -593,7 +593,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) - var ars *arrowRowScanner = d.(*arrowRowScanner) + ars := d.(*arrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -632,7 +632,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) - var ars *arrowRowScanner = d.(*arrowRowScanner) + ars := d.(*arrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -673,7 +673,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) - var ars *arrowRowScanner = d.(*arrowRowScanner) + ars := d.(*arrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -709,7 +709,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + ars := d.(*arrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -766,7 +766,7 @@ func TestArrowRowScanner(t *testing.T) { var scale int32 = 10 var precision int32 = 2 - var columns []*cli_service.TColumnDesc = []*cli_service.TColumnDesc{ + columns := []*cli_service.TColumnDesc{ { ColumnName: "array_col", TypeDesc: &cli_service.TTypeDesc{ @@ -858,7 +858,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + ars := d.(*arrowRowScanner) ars.UseArrowNativeComplexTypes = true ars.UseArrowNativeDecimal = true ars.UseArrowNativeIntervalTypes = true @@ -1202,7 +1202,7 @@ func TestArrowRowScanner(t *testing.T) { // verify that the returned values for the complex type // columns are valid json strings var foo []any - var s string = dest[10].(string) + s := dest[10].(string) err := json.Unmarshal([]byte(s), &foo) assert.Nil(t, err) diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index e12ea4e6..545aa6e7 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -35,8 +35,8 @@ func NewCloudIPCStreamIterator( cfg *config.Config, ) (IPCStreamIterator, dbsqlerr.DBError) { httpClient := http.DefaultClient - if cfg.UserConfig.CloudFetchConfig.HTTPClient != nil { - httpClient = cfg.UserConfig.CloudFetchConfig.HTTPClient + if cfg.HTTPClient != nil { + httpClient = cfg.HTTPClient } bi := &cloudIPCStreamIterator{ @@ -160,7 +160,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { link.RowCount, ) - cancelCtx, cancelFn := context.WithCancel(bi.ctx) + cancelCtx, cancelFn := context.WithCancel(bi.ctx) //nolint:gosec // cancelFn stored in task and called on completion task := &cloudFetchDownloadTask{ ctx: cancelCtx, cancel: cancelFn, @@ -269,7 +269,7 @@ func (cft *cloudFetchDownloadTask) Run() { // Read all data into memory before closing buf, err := io.ReadAll(getReader(data, cft.useLz4Compression)) - data.Close() + data.Close() //nolint:errcheck,gosec // G104: close after reading data if err != nil { cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err} return diff --git 
a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index 99538bbc..5b230610 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -271,7 +271,7 @@ func TestCloudFetchIterator(t *testing.T) { cfg := config.WithDefaults() cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 - cfg.UserConfig.CloudFetchConfig.HTTPClient = customHTTPClient + cfg.HTTPClient = customHTTPClient bi, err := NewCloudBatchIterator( context.Background(), @@ -309,7 +309,7 @@ func TestCloudFetchIterator(t *testing.T) { cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 // Explicitly set HTTPClient to nil to verify fallback behavior - cfg.UserConfig.CloudFetchConfig.HTTPClient = nil + cfg.HTTPClient = nil bi, err := NewCloudBatchIterator( context.Background(), diff --git a/internal/rows/arrowbased/columnValues.go b/internal/rows/arrowbased/columnValues.go index 0b6fc7d8..47105095 100644 --- a/internal/rows/arrowbased/columnValues.go +++ b/internal/rows/arrowbased/columnValues.go @@ -53,7 +53,7 @@ func (rv *rowValues) SetColumnValues(columnIndex int, values arrow.ArrayData) er } func (rv *rowValues) IsNull(columnIndex int, rowNumber int64) bool { - var b bool = true + b := true if columnIndex < len(rv.columnValueHolders) { b = rv.columnValueHolders[columnIndex].IsNull(int(rowNumber - rv.Start())) } diff --git a/internal/rows/columnbased/columnRows.go b/internal/rows/columnbased/columnRows.go index cfaa0daa..b77f230b 100644 --- a/internal/rows/columnbased/columnRows.go +++ b/internal/rows/columnbased/columnRows.go @@ -37,7 +37,7 @@ func NewColumnRowScanner(schema *cli_service.TTableSchema, rowSet *cli_service.T logger = dbsqllog.Logger } - var location *time.Location = time.UTC + location := time.UTC if cfg != nil { if cfg.Location != nil { location = cfg.Location diff --git a/internal/rows/rows.go b/internal/rows/rows.go index 963a3ce1..7ee2db44 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -57,6 +57,11 @@ type rows struct { logger_ *dbsqllog.DBSQLLogger ctx context.Context + + // Telemetry tracking + telemetryUpdate func(chunkCount int, bytesDownloaded int64) + chunkCount int + bytesDownloaded int64 } var _ driver.Rows = (*rows)(nil) @@ -72,6 +77,7 @@ func NewRows( client cli_service.TCLIService, config *config.Config, directResults *cli_service.TSparkDirectResults, + telemetryUpdate func(chunkCount int, bytesDownloaded int64), ) (driver.Rows, dbsqlerr.DBError) { connId := driverctx.ConnIdFromContext(ctx) @@ -91,7 +97,7 @@ func NewRows( } var pageSize int64 = 10000 - var location *time.Location = time.UTC + location := time.UTC if config != nil { pageSize = int64(config.MaxRows) @@ -103,14 +109,17 @@ func NewRows( logger.Debug().Msgf("databricks: creating Rows, pageSize: %d, location: %v", pageSize, location) r := &rows{ - client: client, - opHandle: opHandle, - connId: connId, - correlationId: correlationId, - location: location, - config: config, - logger_: logger, - ctx: ctx, + client: client, + opHandle: opHandle, + connId: connId, + correlationId: correlationId, + location: location, + config: config, + logger_: logger, + ctx: ctx, + telemetryUpdate: telemetryUpdate, + chunkCount: 0, + bytesDownloaded: 0, } // if we already have results for the query do some additional initialization @@ -127,6 +136,17 @@ func NewRows( if err != nil { return r, err } + + r.chunkCount++ + if directResults.ResultSet != nil && directResults.ResultSet.Results != nil && 
directResults.ResultSet.Results.ArrowBatches != nil { + for _, batch := range directResults.ResultSet.Results.ArrowBatches { + r.bytesDownloaded += int64(len(batch.Batch)) + } + } + + if r.telemetryUpdate != nil { + r.telemetryUpdate(r.chunkCount, r.bytesDownloaded) + } } var d rowscanner.Delimiter @@ -433,7 +453,7 @@ func (r *rows) getResultSetSchema() (*cli_service.TTableSchema, dbsqlerr.DBError // fetchResultPage will fetch the result page containing the next row, if necessary func (r *rows) fetchResultPage() error { - var err dbsqlerr.DBError = isValidRows(r) + err := isValidRows(r) if err != nil { return err } @@ -458,6 +478,19 @@ func (r *rows) fetchResultPage() error { return err1 } + r.chunkCount++ + if fetchResult != nil && fetchResult.Results != nil { + if fetchResult.Results.ArrowBatches != nil { + for _, batch := range fetchResult.Results.ArrowBatches { + r.bytesDownloaded += int64(len(batch.Batch)) + } + } + } + + if r.telemetryUpdate != nil { + r.telemetryUpdate(r.chunkCount, r.bytesDownloaded) + } + err1 = r.makeRowScanner(fetchResult) if err1 != nil { return err1 diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index fa6c913a..bb8ac196 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -421,11 +421,11 @@ func TestColumnsWithDirectResults(t *testing.T) { ctx := driverctx.NewContextWithConnId(context.Background(), "connId") ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") - d, err := NewRows(ctx, nil, client, nil, nil) + d, err := NewRows(ctx, nil, client, nil, nil, nil) assert.Nil(t, err) rowSet := d.(*rows) - defer rowSet.Close() + defer rowSet.Close() //nolint:errcheck req2 := &cli_service.TGetResultSetMetadataReq{} metadata, _ := client.GetResultSetMetadata(context.Background(), req2) @@ -720,7 +720,7 @@ func TestRowsCloseOptimization(t *testing.T) { ctx := driverctx.NewContextWithConnId(context.Background(), "connId") ctx = driverctx.NewContextWithCorrelationId(ctx, "corrId") opHandle := &cli_service.TOperationHandle{OperationId: &cli_service.THandleIdentifier{GUID: []byte{'f', 'o'}}} - rowSet, _ := NewRows(ctx, opHandle, client, nil, nil) + rowSet, _ := NewRows(ctx, opHandle, client, nil, nil, nil) // rowSet has no direct results calling Close should result in call to client to close operation err := rowSet.Close() @@ -733,7 +733,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } closeCount = 0 - rowSet, _ = NewRows(ctx, opHandle, client, nil, directResults) + rowSet, _ = NewRows(ctx, opHandle, client, nil, directResults, nil) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 1, closeCount) @@ -746,7 +746,7 @@ func TestRowsCloseOptimization(t *testing.T) { ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{Schema: &cli_service.TTableSchema{}}, ResultSet: &cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{Columns: []*cli_service.TColumn{}}}, } - rowSet, _ = NewRows(ctx, opHandle, client, nil, directResults) + rowSet, _ = NewRows(ctx, opHandle, client, nil, directResults, nil) err = rowSet.Close() assert.Nil(t, err, "rows.Close should not throw an error") assert.Equal(t, 0, closeCount) @@ -816,7 +816,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) cfg := config.WithDefaults() - rows, err := NewRows(ctx, nil, client, cfg, executeStatementResp.DirectResults) + rows, 
err := NewRows(ctx, nil, client, cfg, executeStatementResp.DirectResults, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -889,7 +889,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) cfg := config.WithDefaults() - rows, err := NewRows(ctx, nil, client, cfg, nil) + rows, err := NewRows(ctx, nil, client, cfg, nil, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -950,7 +950,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{fetchResp1}) cfg := config.WithDefaults() - rows, err := NewRows(ctx, nil, client, cfg, nil) + rows, err := NewRows(ctx, nil, client, cfg, nil, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -977,7 +977,7 @@ func TestGetArrowBatches(t *testing.T) { client := getSimpleClient([]cli_service.TFetchResultsResp{}) cfg := config.WithDefaults() - rows, err := NewRows(ctx, nil, client, cfg, executeStatementResp.DirectResults) + rows, err := NewRows(ctx, nil, client, cfg, executeStatementResp.DirectResults, nil) assert.Nil(t, err) rows2, ok := rows.(dbsqlrows.Rows) @@ -1381,17 +1381,18 @@ func getRowsTestSimpleClient(getMetadataCount, fetchResultsCount *int) cli_servi fetchResults := func(ctx context.Context, req *cli_service.TFetchResultsReq) (_r *cli_service.TFetchResultsResp, _err error) { *fetchResultsCount++ - if req.Orientation == cli_service.TFetchOrientation_FETCH_NEXT { + switch req.Orientation { + case cli_service.TFetchOrientation_FETCH_NEXT: if pageIndex+1 >= len(pages) { return nil, errors.New("can't fetch past end of result set") } pageIndex++ - } else if req.Orientation == cli_service.TFetchOrientation_FETCH_PRIOR { + case cli_service.TFetchOrientation_FETCH_PRIOR: if pageIndex-1 < 0 { return nil, errors.New("can't fetch prior to start of result set") } pageIndex-- - } else { + default: return nil, errors.New("invalid fetch results orientation") } @@ -1556,7 +1557,7 @@ func TestFetchResultPage_PropagatesGetNextPageError(t *testing.T) { executeStatementResp := cli_service.TExecuteStatementResp{} cfg := config.WithDefaults() - rows, _ := NewRows(ctx, nil, client, cfg, executeStatementResp.DirectResults) + rows, _ := NewRows(ctx, nil, client, cfg, executeStatementResp.DirectResults, nil) // Call Next and ensure it propagates the error from getNextPage actualErr := rows.Next(nil) diff --git a/internal/rows/rowscanner/resultPageIterator.go b/internal/rows/rowscanner/resultPageIterator.go index 144d7447..9cde0950 100644 --- a/internal/rows/rowscanner/resultPageIterator.go +++ b/internal/rows/rowscanner/resultPageIterator.go @@ -125,7 +125,7 @@ func (rpf *resultPageIterator) HasNext() bool { if rpf.nextResultPage == nil { nrp, err := rpf.getNextPage() if err != nil { - rpf.Close() + rpf.Close() //nolint:errcheck,gosec // G104: close in error path rpf.isFinished = true rpf.err = err return false @@ -134,7 +134,7 @@ func (rpf *resultPageIterator) HasNext() bool { rpf.err = nil rpf.nextResultPage = nrp if !nrp.GetHasMoreRows() { - rpf.Close() + rpf.Close() //nolint:errcheck,gosec // G104: close in error path } } @@ -285,17 +285,18 @@ func CountRows(rowSet *cli_service.TRowSet) int64 { // Check if trying to fetch in the specified direction creates an error condition. 
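+// Fetching forward past the end of the result set yields io.EOF; fetching prior to the start is a driver error.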
func (rpf *resultPageIterator) checkDirectionValid(direction Direction) error { - if direction == DirBack { + switch direction { + case DirBack: // can't fetch rows previous to the start if rpf.Start() == 0 { return dbsqlerrint.NewDriverError(rpf.ctx, ErrRowsFetchPriorToStart, nil) } - } else if direction == DirForward { + case DirForward: // can't fetch past the end of the query results if rpf.isFinished { return io.EOF } - } else { + default: rpf.logger.Error().Msgf(errRowsUnandledFetchDirection(direction.String())) return dbsqlerrint.NewDriverError(rpf.ctx, errRowsUnandledFetchDirection(direction.String()), nil) } diff --git a/internal/rows/rowscanner/rowScanner.go b/internal/rows/rowscanner/rowScanner.go index c83cc952..ebb25bf6 100644 --- a/internal/rows/rowscanner/rowScanner.go +++ b/internal/rows/rowscanner/rowScanner.go @@ -46,7 +46,7 @@ func IsNull(nulls []byte, position int64) bool { index := position / 8 if int64(len(nulls)) > index { b := nulls[index] - return (b & (1 << (uint)(position%8))) != 0 + return (b & (1 << (uint)(position%8))) != 0 //nolint:gosec // position%8 is always 0-7 } return false } diff --git a/telemetry/DESIGN.md b/telemetry/DESIGN.md index 3e742b67..5996ed2e 100644 --- a/telemetry/DESIGN.md +++ b/telemetry/DESIGN.md @@ -1,2257 +1,106 @@ -# Databricks SQL Go Driver: Telemetry Design +# Telemetry Design -## Executive Summary - -This document outlines a **telemetry design** for the Databricks SQL Go driver that collects usage metrics and exports them to the Databricks telemetry service. The design leverages Go's `context.Context` and middleware patterns to instrument driver operations without impacting performance. - -**Important Note:** Telemetry is **disabled by default** and requires explicit opt-in. A gradual rollout strategy will be used to ensure reliability and user control. - -**Key Objectives:** -- Collect driver usage metrics and performance data -- Export aggregated metrics to Databricks telemetry service -- Maintain server-side feature flag control -- Ensure zero impact on driver performance and reliability -- Follow Go best practices and idiomatic patterns - -**Design Principles:** -- **Opt-in first**: User explicit consent required, disabled by default -- **Non-blocking**: All operations async using goroutines -- **Privacy-first**: No PII or query data collected -- **Server-controlled**: Feature flag support for gradual rollout -- **Fail-safe**: All telemetry errors swallowed silently -- **Idiomatic Go**: Use standard library patterns and interfaces - -**Production Requirements** (from JDBC driver experience): -- **Feature flag caching**: Per-host caching to avoid rate limiting -- **Circuit breaker**: Protect against telemetry endpoint failures -- **Exception swallowing**: All telemetry errors caught with minimal logging -- **Per-host telemetry client**: One client per host to prevent rate limiting -- **Graceful shutdown**: Proper cleanup with reference counting -- **Smart exception flushing**: Only flush terminal errors immediately - ---- - -## Table of Contents - -1. [Background & Motivation](#1-background--motivation) -2. [Architecture Overview](#2-architecture-overview) -3. 
[Core Components](#3-core-components) - - 3.1 [featureFlagCache (Per-Host)](#31-featureflagcache-per-host) - - 3.2 [clientManager (Per-Host)](#32-clientmanager-per-host) - - 3.3 [circuitBreaker](#33-circuitbreaker) - - 3.4 [Telemetry Interceptor](#34-telemetry-interceptor) - - 3.5 [metricsAggregator](#35-metricsaggregator) - - 3.6 [telemetryExporter](#36-telemetryexporter) -4. [Data Collection](#4-data-collection) -5. [Export Mechanism](#5-export-mechanism) -6. [Configuration](#6-configuration) - - 6.1 [Configuration Structure](#61-configuration-structure) - - 6.2 [Configuration from DSN](#62-configuration-from-dsn) - - 6.3 [Feature Flag Integration](#63-feature-flag-integration) - - 6.4 [Opt-In Control & Priority](#64-opt-in-control--priority) -7. [Privacy & Compliance](#7-privacy--compliance) -8. [Error Handling](#8-error-handling) -9. [Graceful Shutdown](#9-graceful-shutdown) -10. [Testing Strategy](#10-testing-strategy) -11. [Partial Launch Strategy](#11-partial-launch-strategy) -12. [Implementation Checklist](#12-implementation-checklist) -13. [References](#13-references) - ---- - -## 1. Background & Motivation - -### 1.1 Current State - -The Databricks SQL Go driver (`databricks-sql-go`) provides: -- ✅ Standard `database/sql/driver` interface implementation -- ✅ Thrift-based communication with Databricks clusters -- ✅ Connection pooling via `database/sql` package -- ✅ Context-based cancellation and timeouts -- ✅ Structured logging - -### 1.2 Requirements - -Product teams need telemetry data to: -- Understand driver feature adoption (CloudFetch, LZ4 compression, etc.) -- Identify performance bottlenecks -- Track error rates and failure modes -- Make data-driven decisions for driver improvements - -### 1.3 Design Goals - -**Instrumentation Strategy**: -- ✅ Minimal code changes to existing driver -- ✅ Use Go idiomatic patterns (interfaces, middleware, context) -- ✅ Zero performance impact when telemetry disabled -- ✅ < 1% overhead when telemetry enabled -- ✅ No risk to driver reliability - ---- - -## 2. Architecture Overview - -### 2.1 High-Level Architecture - -```mermaid -graph TB - A[Driver Operations] -->|context.Context| B[Telemetry Interceptor] - B -->|Collect Metrics| C[metricsAggregator] - C -->|Batch & Buffer| D[clientManager] - D -->|Get Per-Host Client| E[telemetryClient per Host] - E -->|Check Circuit Breaker| F[circuitBreaker] - F -->|HTTP POST| G[telemetryExporter] - G --> H[Databricks Service] - H --> I[Lumberjack] - - J[featureFlagCache per Host] -.->|Enable/Disable| B - K[Connection Open] -->|Increment RefCount| D - K -->|Increment RefCount| J - L[Connection Close] -->|Decrement RefCount| D - L -->|Decrement RefCount| J - - style B fill:#e1f5fe - style C fill:#e1f5fe - style D fill:#ffe0b2 - style E fill:#ffe0b2 - style F fill:#ffccbc - style J fill:#c8e6c9 -``` - -**Key Components:** -1. **Telemetry Interceptor** (new): Wraps driver operations to collect metrics -2. **featureFlagCache** (new): Per-host caching of feature flags with reference counting -3. **clientManager** (new): Manages one telemetry client per host with reference counting -4. **circuitBreaker** (new): Protects against failing telemetry endpoint -5. **metricsAggregator** (new): Aggregates by statement, batches events -6. 
**telemetryExporter** (new): Exports to Databricks service - -### 2.2 Data Flow - -```mermaid -sequenceDiagram - participant App as Application - participant Driver as dbsql.conn - participant Interceptor as Telemetry Interceptor - participant Agg as metricsAggregator - participant Exp as telemetryExporter - participant Service as Databricks Service - - App->>Driver: QueryContext(ctx, query) - Driver->>Interceptor: Before Execute - Interceptor->>Interceptor: Start timer - - Driver->>Driver: Execute operation - Driver->>Interceptor: After Execute (with result/error) - - Interceptor->>Agg: RecordMetric(ctx, metric) - Agg->>Agg: Aggregate by statement_id - - alt Batch threshold reached - Agg->>Exp: Flush(batch) - Exp->>Service: POST /telemetry-ext - end -``` - ---- - -## 3. Core Components - -### 3.1 featureFlagCache (Per-Host) - -**Purpose**: Cache feature flag values at the host level to avoid repeated API calls and rate limiting. - -**Location**: `telemetry/featureflag.go` - -#### Rationale -- **Per-host caching**: Feature flags cached by host to prevent rate limiting -- **Multi-flag support**: Fetches all flags in a single request for efficiency -- **Reference counting**: Tracks number of connections per host for proper cleanup -- **Automatic expiration**: Refreshes cached flags after TTL expires (15 minutes) -- **Thread-safe**: Uses sync.RWMutex for concurrent access -- **Synchronous fetch**: Blocks on cache miss (see Section 6.3 for behavior details) -- **Thundering herd protection**: Only one fetch per host at a time - -#### Interface - -```go -package telemetry - -import ( - "context" - "net/http" - "sync" - "time" -) - -// featureFlagCache manages feature flag state per host with reference counting. -// This prevents rate limiting by caching feature flag responses. -type featureFlagCache struct { - mu sync.RWMutex - contexts map[string]*featureFlagContext -} - -// featureFlagContext holds feature flag state and reference count for a host. -type featureFlagContext struct { - enabled *bool - lastFetched time.Time - refCount int - cacheDuration time.Duration -} - -var ( - flagCacheOnce sync.Once - flagCacheInstance *featureFlagCache -) - -// getFeatureFlagCache returns the singleton instance. -func getFeatureFlagCache() *featureFlagCache { - flagCacheOnce.Do(func() { - flagCacheInstance = &featureFlagCache{ - contexts: make(map[string]*featureFlagContext), - } - }) - return flagCacheInstance -} - -// getOrCreateContext gets or creates a feature flag context for the host. -// Increments reference count. -func (c *featureFlagCache) getOrCreateContext(host string) *featureFlagContext { - c.mu.Lock() - defer c.mu.Unlock() - - ctx, exists := c.contexts[host] - if !exists { - ctx = &featureFlagContext{ - cacheDuration: 15 * time.Minute, - } - c.contexts[host] = ctx - } - ctx.refCount++ - return ctx -} - -// releaseContext decrements reference count for the host. -// Removes context when ref count reaches zero. -func (c *featureFlagCache) releaseContext(host string) { - c.mu.Lock() - defer c.mu.Unlock() - - if ctx, exists := c.contexts[host]; exists { - ctx.refCount-- - if ctx.refCount <= 0 { - delete(c.contexts, host) - } - } -} - -// isTelemetryEnabled checks if telemetry is enabled for the host. -// Uses cached value if available and not expired. 
-func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, httpClient *http.Client) (bool, error) { - c.mu.RLock() - flagCtx, exists := c.contexts[host] - c.mu.RUnlock() - - if !exists { - return false, nil - } - - // Check if cache is valid - if flagCtx.enabled != nil && time.Since(flagCtx.lastFetched) < flagCtx.cacheDuration { - return *flagCtx.enabled, nil - } - - // Fetch fresh value - enabled, err := fetchFeatureFlag(ctx, host, httpClient) - if err != nil { - // Return cached value on error, or false if no cache - if flagCtx.enabled != nil { - return *flagCtx.enabled, nil - } - return false, err - } - - // Update cache - c.mu.Lock() - flagCtx.enabled = &enabled - flagCtx.lastFetched = time.Now() - c.mu.Unlock() - - return enabled, nil -} - -// isExpired returns true if the cache has expired. -func (c *featureFlagContext) isExpired() bool { - return c.enabled == nil || time.Since(c.lastFetched) > c.cacheDuration -} -``` - ---- - -### 3.2 clientManager (Per-Host) - -**Purpose**: Manage one telemetry client per host to prevent rate limiting from concurrent connections. - -**Location**: `telemetry/manager.go` - -#### Rationale -- **One client per host**: Large customers open many parallel connections to the same host -- **Prevents rate limiting**: Shared client batches events from all connections -- **Reference counting**: Tracks active connections, only closes client when last connection closes -- **Thread-safe**: Safe for concurrent access from multiple goroutines - -#### Interface - -```go -package telemetry - -import ( - "net/http" - "sync" -) - -// clientManager manages one telemetry client per host. -// Prevents rate limiting by sharing clients across connections. -type clientManager struct { - mu sync.RWMutex - clients map[string]*clientHolder -} - -// clientHolder holds a telemetry client and its reference count. -type clientHolder struct { - client *telemetryClient - refCount int -} - -var ( - managerOnce sync.Once - managerInstance *clientManager -) - -// getClientManager returns the singleton instance. -func getClientManager() *clientManager { - managerOnce.Do(func() { - managerInstance = &clientManager{ - clients: make(map[string]*clientHolder), - } - }) - return managerInstance -} - -// getOrCreateClient gets or creates a telemetry client for the host. -// Increments reference count. -func (m *clientManager) getOrCreateClient(host string, httpClient *http.Client, cfg *config) *telemetryClient { - m.mu.Lock() - defer m.mu.Unlock() - - holder, exists := m.clients[host] - if !exists { - holder = &clientHolder{ - client: newTelemetryClient(host, httpClient, cfg), - } - m.clients[host] = holder - holder.client.start() // Start background flush goroutine - } - holder.refCount++ - return holder.client -} - -// releaseClient decrements reference count for the host. -// Closes and removes client when ref count reaches zero. -func (m *clientManager) releaseClient(host string) error { - m.mu.Lock() - holder, exists := m.clients[host] - if !exists { - m.mu.Unlock() - return nil - } - - holder.refCount-- - if holder.refCount <= 0 { - delete(m.clients, host) - m.mu.Unlock() - return holder.client.close() // Close and flush - } - - m.mu.Unlock() - return nil -} -``` - ---- - -### 3.3 circuitBreaker - -**Purpose**: Implement circuit breaker pattern to protect against failing telemetry endpoint using a sliding window and failure rate percentage algorithm (matching JDBC's Resilience4j implementation). 
- -**Location**: `telemetry/circuitbreaker.go` - -#### Rationale -- **Endpoint protection**: The telemetry endpoint itself may fail or become unavailable -- **Not just rate limiting**: Protects against 5xx errors, timeouts, network failures -- **Resource efficiency**: Prevents wasting resources on a failing endpoint -- **Auto-recovery**: Automatically detects when endpoint becomes healthy again -- **JDBC alignment**: Uses sliding window with failure rate percentage, matching JDBC driver behavior exactly - -#### Algorithm: Sliding Window with Failure Rate -The circuit breaker tracks recent calls in a **sliding window** (ring buffer) and calculates the **failure rate percentage**: -- Tracks the last N calls (default: 30) -- Opens circuit when failure rate >= threshold (default: 50%) -- Requires minimum calls before evaluation (default: 20) -- Uses percentage-based evaluation instead of consecutive failures - -**Example**: With 30 calls in window, if 15 or more fail (50%), circuit opens. This is more robust than consecutive-failure counting as it considers overall reliability. - -#### States -1. **Closed**: Normal operation, requests pass through -2. **Open**: After failure rate exceeds threshold, all requests rejected immediately (drop events) -3. **Half-Open**: After wait duration, allows test requests to check if endpoint recovered - -#### Configuration (matching JDBC defaults) -- **failureRateThreshold**: 50% - Opens circuit if failure rate >= 50% -- **minimumNumberOfCalls**: 20 - Minimum calls before evaluating failure rate -- **slidingWindowSize**: 30 - Track last 30 calls in sliding window -- **waitDurationInOpenState**: 30s - Wait before transitioning to half-open -- **permittedCallsInHalfOpen**: 3 - Test with 3 successful calls before closing - -#### Interface - -```go -package telemetry - -import ( - "context" - "errors" - "sync" - "sync/atomic" - "time" -) - -// circuitState represents the state of the circuit breaker. -type circuitState int32 - -const ( - stateClosed circuitState = iota - stateOpen - stateHalfOpen -) - -// callResult represents the result of a call (success or failure). -type callResult bool - -const ( - callSuccess callResult = true - callFailure callResult = false -) - -// circuitBreaker implements the circuit breaker pattern. -// It protects against failing telemetry endpoints by tracking failures -// using a sliding window and failure rate percentage. -type circuitBreaker struct { - mu sync.RWMutex - - state atomic.Int32 // circuitState - lastStateTime time.Time - - // Sliding window for tracking calls - window []callResult - windowIndex int - windowFilled bool - totalCalls int - failureCount int - - // Half-open state tracking - halfOpenSuccesses int - - config circuitBreakerConfig -} - -// circuitBreakerConfig holds circuit breaker configuration. -type circuitBreakerConfig struct { - failureRateThreshold int // Open if failure rate >= this percentage (0-100) - minimumNumberOfCalls int // Minimum calls before evaluating failure rate - slidingWindowSize int // Number of recent calls to track - waitDurationInOpenState time.Duration // Wait before transitioning to half-open - permittedCallsInHalfOpen int // Number of test calls in half-open state -} - -// defaultCircuitBreakerConfig returns default configuration matching JDBC. 
-func defaultCircuitBreakerConfig() circuitBreakerConfig { - return circuitBreakerConfig{ - failureRateThreshold: 50, // 50% failure rate - minimumNumberOfCalls: 20, // Minimum sample size - slidingWindowSize: 30, // Keep recent 30 calls - waitDurationInOpenState: 30 * time.Second, - permittedCallsInHalfOpen: 3, // Test with 3 calls - } -} - -// newCircuitBreaker creates a new circuit breaker. -func newCircuitBreaker(cfg circuitBreakerConfig) *circuitBreaker { - cb := &circuitBreaker{ - config: cfg, - lastStateTime: time.Now(), - window: make([]callResult, cfg.slidingWindowSize), - } - cb.state.Store(int32(stateClosed)) - return cb -} - -// ErrCircuitOpen is returned when circuit is open. -var ErrCircuitOpen = errors.New("circuit breaker is open") - -// execute executes the function if circuit allows. -func (cb *circuitBreaker) execute(ctx context.Context, fn func() error) error { - state := circuitState(cb.state.Load()) - - switch state { - case stateOpen: - // Check if wait duration has passed - cb.mu.RLock() - shouldRetry := time.Since(cb.lastStateTime) > cb.config.waitDurationInOpenState - cb.mu.RUnlock() - - if shouldRetry { - // Transition to half-open - cb.setState(stateHalfOpen) - return cb.tryExecute(ctx, fn) - } - return ErrCircuitOpen - - case stateHalfOpen: - return cb.tryExecute(ctx, fn) - - case stateClosed: - return cb.tryExecute(ctx, fn) - } - - return nil -} - -// tryExecute attempts to execute the function and updates state. -func (cb *circuitBreaker) tryExecute(ctx context.Context, fn func() error) error { - err := fn() - - if err != nil { - cb.recordCall(callFailure) - return err - } - - cb.recordCall(callSuccess) - return nil -} - -// recordCall records a call result in the sliding window and evaluates state transitions. -func (cb *circuitBreaker) recordCall(result callResult) { - cb.mu.Lock() - defer cb.mu.Unlock() - - state := circuitState(cb.state.Load()) - - // Handle half-open state specially - if state == stateHalfOpen { - if result == callFailure { - // Any failure in half-open immediately reopens circuit - cb.resetWindowUnlocked() - cb.setStateUnlocked(stateOpen) - return - } - - cb.halfOpenSuccesses++ - if cb.halfOpenSuccesses >= cb.config.permittedCallsInHalfOpen { - // Enough successes to close circuit - cb.resetWindowUnlocked() - cb.setStateUnlocked(stateClosed) - } - return - } - - // Record in sliding window - // Remove old value from count if window is full - if cb.windowFilled && cb.window[cb.windowIndex] == callFailure { - cb.failureCount-- - } - - // Add new value - cb.window[cb.windowIndex] = result - if result == callFailure { - cb.failureCount++ - } - - // Move to next position - cb.windowIndex = (cb.windowIndex + 1) % cb.config.slidingWindowSize - if cb.windowIndex == 0 { - cb.windowFilled = true - } - - cb.totalCalls++ - - // Evaluate if we should open the circuit - if state == stateClosed { - cb.evaluateStateUnlocked() - } -} - -// evaluateStateUnlocked checks if the circuit should open based on failure rate. -// Caller must hold cb.mu lock. -func (cb *circuitBreaker) evaluateStateUnlocked() { - // Need minimum number of calls before evaluating - windowSize := cb.totalCalls - if cb.windowFilled { - windowSize = cb.config.slidingWindowSize - } - - if windowSize < cb.config.minimumNumberOfCalls { - return - } - - // Calculate failure rate - failureRate := (cb.failureCount * 100) / windowSize - - if failureRate >= cb.config.failureRateThreshold { - cb.setStateUnlocked(stateOpen) - } -} - -// resetWindowUnlocked clears the sliding window. 
-// Caller must hold cb.mu lock. -func (cb *circuitBreaker) resetWindowUnlocked() { - cb.windowIndex = 0 - cb.windowFilled = false - cb.totalCalls = 0 - cb.failureCount = 0 - cb.halfOpenSuccesses = 0 -} - -// setState transitions to a new state. -func (cb *circuitBreaker) setState(newState circuitState) { - cb.mu.Lock() - defer cb.mu.Unlock() - cb.setStateUnlocked(newState) -} - -// setStateUnlocked transitions to a new state without locking. -// Caller must hold cb.mu lock. -func (cb *circuitBreaker) setStateUnlocked(newState circuitState) { - oldState := circuitState(cb.state.Load()) - if oldState == newState { - return - } - - cb.state.Store(int32(newState)) - cb.lastStateTime = time.Now() - - // Log state transition at DEBUG level - // logger.Debug().Msgf("circuit breaker: %v -> %v", oldState, newState) -} - -// circuitBreakerManager manages circuit breakers per host. -// Each host gets its own circuit breaker to provide isolation. -type circuitBreakerManager struct { - mu sync.RWMutex - breakers map[string]*circuitBreaker -} - -var ( - breakerManagerOnce sync.Once - breakerManagerInstance *circuitBreakerManager -) - -// getCircuitBreakerManager returns the singleton instance. -func getCircuitBreakerManager() *circuitBreakerManager { - breakerManagerOnce.Do(func() { - breakerManagerInstance = &circuitBreakerManager{ - breakers: make(map[string]*circuitBreaker), - } - }) - return breakerManagerInstance -} - -// getCircuitBreaker gets or creates a circuit breaker for the host. -// Thread-safe for concurrent access. -func (m *circuitBreakerManager) getCircuitBreaker(host string) *circuitBreaker { - m.mu.RLock() - cb, exists := m.breakers[host] - m.mu.RUnlock() - - if exists { - return cb - } - - m.mu.Lock() - defer m.mu.Unlock() - - // Double-check after acquiring write lock - if cb, exists = m.breakers[host]; exists { - return cb - } - - cb = newCircuitBreaker(defaultCircuitBreakerConfig()) - m.breakers[host] = cb - return cb -} -``` - ---- - -### 3.4 Telemetry Interceptor - -**Purpose**: Intercept driver operations to collect metrics without modifying core driver logic. - -**Location**: `telemetry/interceptor.go` - -#### Approach - -Use Go's context and function wrapping patterns to intercept operations: - -```go -package telemetry - -import ( - "context" - "time" -) - -// interceptor wraps driver operations to collect metrics. -type interceptor struct { - aggregator *metricsAggregator - enabled bool -} - -// metricContext holds metric collection state in context. -type metricContext struct { - statementID string - startTime time.Time - tags map[string]interface{} -} - -type contextKey int - -const metricContextKey contextKey = 0 - -// withMetricContext adds metric context to the context. -func withMetricContext(ctx context.Context, mc *metricContext) context.Context { - return context.WithValue(ctx, metricContextKey, mc) -} - -// getMetricContext retrieves metric context from the context. -func getMetricContext(ctx context.Context) *metricContext { - if mc, ok := ctx.Value(metricContextKey).(*metricContext); ok { - return mc - } - return nil -} - -// beforeExecute is called before statement execution. -func (i *interceptor) beforeExecute(ctx context.Context, statementID string) context.Context { - if !i.enabled { - return ctx - } - - mc := &metricContext{ - statementID: statementID, - startTime: time.Now(), - tags: make(map[string]interface{}), - } - - return withMetricContext(ctx, mc) -} - -// afterExecute is called after statement execution. 
-func (i *interceptor) afterExecute(ctx context.Context, err error) {
-	if !i.enabled {
-		return
-	}
-
-	mc := getMetricContext(ctx)
-	if mc == nil {
-		return
-	}
-
-	// Swallow all panics
-	defer func() {
-		if r := recover(); r != nil {
-			// Log at trace level only
-		}
-	}()
-
-	metric := &telemetryMetric{
-		metricType:  "statement",
-		timestamp:   mc.startTime,
-		statementID: mc.statementID,
-		latencyMs:   time.Since(mc.startTime).Milliseconds(),
-		tags:        mc.tags,
-	}
-
-	if err != nil {
-		metric.errorType = classifyError(err)
-	}
-
-	// Non-blocking send
-	i.aggregator.recordMetric(ctx, metric)
-}
-
-// addTag adds a tag to the current metric context.
-func (i *interceptor) addTag(ctx context.Context, key string, value interface{}) {
-	if !i.enabled {
-		return
-	}
-
-	mc := getMetricContext(ctx)
-	if mc != nil {
-		mc.tags[key] = value
-	}
-}
-```
-
-**Integration points in driver**:
-
-```go
-// In statement.go
-// Named return values let the deferred hook observe the final error.
-func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
-	// Wrap context with telemetry
-	if s.conn.telemetry != nil {
-		ctx = s.conn.telemetry.beforeExecute(ctx, generateStatementID())
-		defer func() {
-			s.conn.telemetry.afterExecute(ctx, err)
-		}()
-	}
-
-	// Existing implementation
-	// ...
-}
-```
-
----
-
-### 3.5 metricsAggregator
-
-**Purpose**: Aggregate metrics by statement and batch for efficient export.
-
-**Location**: `telemetry/aggregator.go`
-
-#### Interface
-
-```go
-package telemetry
-
-import (
-	"context"
-	"errors"
-	"sync"
-	"time"
-)
-
-// metricsAggregator aggregates metrics by statement and batches for export.
-type metricsAggregator struct {
-	mu sync.RWMutex
-
-	statements map[string]*statementMetrics
-	batch      []*telemetryMetric
-	exporter   *telemetryExporter
-
-	batchSize     int
-	flushInterval time.Duration
-	stopCh        chan struct{}
-}
-
-// statementMetrics holds aggregated metrics for a statement.
-type statementMetrics struct {
-	statementID     string
-	sessionID       string
-	totalLatency    time.Duration
-	chunkCount      int
-	bytesDownloaded int64
-	pollCount       int
-	errors          []error
-	tags            map[string]interface{}
-}
-
-// newMetricsAggregator creates a new metrics aggregator.
-func newMetricsAggregator(exporter *telemetryExporter, cfg *config) *metricsAggregator {
-	agg := &metricsAggregator{
-		statements:    make(map[string]*statementMetrics),
-		batch:         make([]*telemetryMetric, 0, cfg.batchSize),
-		exporter:      exporter,
-		batchSize:     cfg.batchSize,
-		flushInterval: cfg.flushInterval,
-		stopCh:        make(chan struct{}),
-	}
-
-	// Start background flush timer
-	go agg.flushLoop()
-
-	return agg
-}
-
-// recordMetric records a metric for aggregation.
-func (agg *metricsAggregator) recordMetric(ctx context.Context, metric *telemetryMetric) { - // Swallow all errors - defer func() { - if r := recover(); r != nil { - // Log at trace level only - } - }() - - agg.mu.Lock() - defer agg.mu.Unlock() - - switch metric.metricType { - case "connection": - // Emit immediately - agg.batch = append(agg.batch, metric) - if len(agg.batch) >= agg.batchSize { - agg.flushUnlocked(ctx) - } - - case "statement": - // Aggregate by statement ID - stmt, exists := agg.statements[metric.statementID] - if !exists { - stmt = &statementMetrics{ - statementID: metric.statementID, - tags: make(map[string]interface{}), - } - agg.statements[metric.statementID] = stmt - } - - // Update aggregated values - stmt.totalLatency += time.Duration(metric.latencyMs) * time.Millisecond - if chunkCount, ok := metric.tags["chunk_count"].(int); ok { - stmt.chunkCount += chunkCount - } - if bytes, ok := metric.tags["bytes_downloaded"].(int64); ok { - stmt.bytesDownloaded += bytes - } - - // Merge tags - for k, v := range metric.tags { - stmt.tags[k] = v - } - - case "error": - // Check if terminal - if isTerminalError(metric.errorType) { - // Flush immediately - agg.batch = append(agg.batch, metric) - agg.flushUnlocked(ctx) - } else { - // Buffer with statement - if stmt, exists := agg.statements[metric.statementID]; exists { - stmt.errors = append(stmt.errors, errors.New(metric.errorType)) - } - } - } -} - -// completeStatement marks a statement as complete and emits aggregated metric. -func (agg *metricsAggregator) completeStatement(ctx context.Context, statementID string, failed bool) { - defer func() { - if r := recover(); r != nil { - // Log at trace level only - } - }() - - agg.mu.Lock() - defer agg.mu.Unlock() - - stmt, exists := agg.statements[statementID] - if !exists { - return - } - delete(agg.statements, statementID) - - // Create aggregated metric - metric := &telemetryMetric{ - metricType: "statement", - timestamp: time.Now(), - statementID: stmt.statementID, - sessionID: stmt.sessionID, - latencyMs: stmt.totalLatency.Milliseconds(), - tags: stmt.tags, - } - - // Add aggregated counts - metric.tags["chunk_count"] = stmt.chunkCount - metric.tags["bytes_downloaded"] = stmt.bytesDownloaded - metric.tags["poll_count"] = stmt.pollCount - - agg.batch = append(agg.batch, metric) - - // Emit errors if statement failed - if failed && len(stmt.errors) > 0 { - for _, err := range stmt.errors { - errorMetric := &telemetryMetric{ - metricType: "error", - timestamp: time.Now(), - statementID: statementID, - errorType: err.Error(), - } - agg.batch = append(agg.batch, errorMetric) - } - } - - // Flush if batch full - if len(agg.batch) >= agg.batchSize { - agg.flushUnlocked(ctx) - } -} - -// flushLoop runs periodic flush in background. -func (agg *metricsAggregator) flushLoop() { - ticker := time.NewTicker(agg.flushInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - agg.flush(context.Background()) - case <-agg.stopCh: - return - } - } -} - -// flush flushes pending metrics to exporter. -func (agg *metricsAggregator) flush(ctx context.Context) { - agg.mu.Lock() - defer agg.mu.Unlock() - agg.flushUnlocked(ctx) -} - -// flushUnlocked flushes without locking (caller must hold lock). 
-func (agg *metricsAggregator) flushUnlocked(ctx context.Context) { - if len(agg.batch) == 0 { - return - } - - // Copy batch and clear - metrics := make([]*telemetryMetric, len(agg.batch)) - copy(metrics, agg.batch) - agg.batch = agg.batch[:0] - - // Export asynchronously - go func() { - defer func() { - if r := recover(); r != nil { - // Log at trace level only - } - }() - agg.exporter.export(ctx, metrics) - }() -} - -// close stops the aggregator and flushes pending metrics. -func (agg *metricsAggregator) close(ctx context.Context) error { - close(agg.stopCh) - agg.flush(ctx) - return nil -} -``` +Telemetry is **disabled by default**. It collects driver usage metrics (latency, error types, feature flags) and exports them to the Databricks telemetry service. No SQL text, PII, or query results are ever collected. --- -### 3.6 telemetryExporter +## Architecture -**Purpose**: Export metrics to Databricks telemetry service. - -**Location**: `telemetry/exporter.go` - -#### Interface - -```go -package telemetry - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "time" -) - -// telemetryExporter exports metrics to Databricks service. -type telemetryExporter struct { - host string - httpClient *http.Client - circuitBreaker *circuitBreaker - cfg *config -} - -// telemetryMetric represents a metric to export. -type telemetryMetric struct { - metricType string - timestamp time.Time - workspaceID string - sessionID string - statementID string - latencyMs int64 - errorType string - tags map[string]interface{} -} - -// newTelemetryExporter creates a new exporter. -func newTelemetryExporter(host string, httpClient *http.Client, cfg *config) *telemetryExporter { - return &telemetryExporter{ - host: host, - httpClient: httpClient, - circuitBreaker: getCircuitBreakerManager().getCircuitBreaker(host), - cfg: cfg, - } -} - -// export exports metrics to Databricks service. -func (e *telemetryExporter) export(ctx context.Context, metrics []*telemetryMetric) { - // Swallow all errors - defer func() { - if r := recover(); r != nil { - // Log at trace level only - } - }() - - // Check circuit breaker - err := e.circuitBreaker.execute(ctx, func() error { - return e.doExport(ctx, metrics) - }) - - if err == ErrCircuitOpen { - // Drop metrics silently - return - } - - if err != nil { - // Log at trace level only - } -} - -// doExport performs the actual export with retries. 
-func (e *telemetryExporter) doExport(ctx context.Context, metrics []*telemetryMetric) error {
-	// Convert to the export format first: telemetryMetric's fields are
-	// unexported, so marshaling it directly would produce empty objects.
-	exported := make([]*exportedMetric, 0, len(metrics))
-	for _, m := range metrics {
-		exported = append(exported, m.toExportedMetric())
-	}
-
-	// Serialize metrics
-	data, err := json.Marshal(&telemetryPayload{Metrics: exported})
-	if err != nil {
-		return fmt.Errorf("failed to marshal metrics: %w", err)
-	}
-
-	// Determine endpoint
-	endpoint := fmt.Sprintf("https://%s/api/2.0/telemetry-ext", e.host)
-
-	// Retry logic: build a fresh request per attempt so the body reader is
-	// not already exhausted by a previous attempt.
-	maxRetries := e.cfg.maxRetries
-	for attempt := 0; attempt <= maxRetries; attempt++ {
-		if attempt > 0 {
-			// Exponential backoff: base delay doubled per retry
-			backoff := e.cfg.retryDelay * time.Duration(1<<(attempt-1))
-			time.Sleep(backoff)
-		}
-
-		req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(data))
-		if err != nil {
-			return fmt.Errorf("failed to create request: %w", err)
-		}
-		req.Header.Set("Content-Type", "application/json")
-
-		resp, err := e.httpClient.Do(req)
-		if err != nil {
-			if attempt == maxRetries {
-				return fmt.Errorf("failed after %d retries: %w", maxRetries, err)
-			}
-			continue
-		}
-
-		resp.Body.Close()
-
-		// Check status code
-		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
-			return nil // Success
-		}
-
-		// Check if retryable
-		if !isRetryableStatus(resp.StatusCode) {
-			return fmt.Errorf("non-retryable status: %d", resp.StatusCode)
-		}
-
-		if attempt == maxRetries {
-			return fmt.Errorf("failed after %d retries: status %d", maxRetries, resp.StatusCode)
-		}
-	}
-
-	return nil
-}
-
-// isRetryableStatus returns true if HTTP status is retryable (429 or any 5xx).
-func isRetryableStatus(status int) bool {
-	return status == 429 || status >= 500
-}
-```
-
----
-
-## 4. Data Collection
-
-### 4.1 Metric Tags
-
-All metric tags are defined in a centralized location for maintainability and security.
-
-**Location**: `telemetry/tags.go`
-
-```go
-package telemetry
-
-// Tag names for connection metrics
-const (
-	TagWorkspaceID   = "workspace.id"
-	TagSessionID     = "session.id"
-	TagDriverVersion = "driver.version"
-	TagDriverOS      = "driver.os"
-	TagDriverRuntime = "driver.runtime"
-	TagServerAddress = "server.address" // Not exported to Databricks
-)
-
-// Tag names for statement metrics
-const (
-	TagStatementID           = "statement.id"
-	TagResultFormat          = "result.format"
-	TagResultChunkCount      = "result.chunk_count"
-	TagResultBytesDownloaded = "result.bytes_downloaded"
-	TagCompressionEnabled    = "result.compression_enabled"
-	TagPollCount             = "poll.count"
-	TagPollLatency           = "poll.latency_ms"
-)
-
-// Tag names for error metrics
-const (
-	TagErrorType = "error.type"
-	TagErrorCode = "error.code"
-)
-
-// Feature flag tags
-const (
-	TagFeatureCloudFetch    = "feature.cloudfetch"
-	TagFeatureLZ4           = "feature.lz4"
-	TagFeatureDirectResults = "feature.direct_results"
-)
-
-// tagExportScope defines where a tag can be exported.
-type tagExportScope int
-
-const (
-	exportNone       tagExportScope = 0
-	exportLocal      tagExportScope = 1 << 0
-	exportDatabricks tagExportScope = 1 << 1
-	exportAll                       = exportLocal | exportDatabricks
-)
-
-// tagDefinition defines a metric tag and its export scope.
-type tagDefinition struct {
-	name        string
-	exportScope tagExportScope
-	description string
-	required    bool
-}
-
-// connectionTags returns tags allowed for connection events.
-func connectionTags() []tagDefinition { - return []tagDefinition{ - {TagWorkspaceID, exportDatabricks, "Databricks workspace ID", true}, - {TagSessionID, exportDatabricks, "Connection session ID", true}, - {TagDriverVersion, exportAll, "Driver version", false}, - {TagDriverOS, exportAll, "Operating system", false}, - {TagDriverRuntime, exportAll, "Go runtime version", false}, - {TagFeatureCloudFetch, exportDatabricks, "CloudFetch enabled", false}, - {TagFeatureLZ4, exportDatabricks, "LZ4 compression enabled", false}, - {TagServerAddress, exportLocal, "Server address (local only)", false}, - } -} - -// statementTags returns tags allowed for statement events. -func statementTags() []tagDefinition { - return []tagDefinition{ - {TagStatementID, exportDatabricks, "Statement ID", true}, - {TagSessionID, exportDatabricks, "Session ID", true}, - {TagResultFormat, exportDatabricks, "Result format", false}, - {TagResultChunkCount, exportDatabricks, "Chunk count", false}, - {TagResultBytesDownloaded, exportDatabricks, "Bytes downloaded", false}, - {TagCompressionEnabled, exportDatabricks, "Compression enabled", false}, - {TagPollCount, exportDatabricks, "Poll count", false}, - {TagPollLatency, exportDatabricks, "Poll latency", false}, - } -} - -// shouldExportToDatabricks returns true if tag should be exported to Databricks. -func shouldExportToDatabricks(metricType, tagName string) bool { - var tags []tagDefinition - switch metricType { - case "connection": - tags = connectionTags() - case "statement": - tags = statementTags() - default: - return false - } - - for _, tag := range tags { - if tag.name == tagName { - return tag.exportScope&exportDatabricks != 0 - } - } - return false -} +Driver operations + → Interceptor (collect + tag metrics via context) + → metricsAggregator (aggregate by statement, batch by size/time) + → telemetryExporter (HTTP POST with retries) + → circuitBreaker (protect against endpoint failures) + → /telemetry-ext ``` -### 4.2 Metric Collection Points - -```go -// In connection.go - Connection open -func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { - // ... existing code ... - - conn := &conn{ - id: generateConnectionID(), - cfg: c.cfg, - client: tclient, - session: session, - } - - // Initialize telemetry if enabled - if c.cfg.telemetryEnabled { - conn.telemetry = newTelemetryInterceptor(conn.id, c.cfg) - - // Record connection open - conn.telemetry.recordConnection(ctx, map[string]interface{}{ - TagWorkspaceID: c.cfg.WorkspaceID, - TagSessionID: conn.id, - TagDriverVersion: c.cfg.DriverVersion, - TagDriverOS: runtime.GOOS, - TagDriverRuntime: runtime.Version(), - }) - } - - return conn, nil -} - -// In statement.go - Statement execution -func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - statementID := generateStatementID() - - // Start telemetry - if s.conn.telemetry != nil { - ctx = s.conn.telemetry.beforeExecute(ctx, statementID) - defer func() { - s.conn.telemetry.afterExecute(ctx, err) - s.conn.telemetry.completeStatement(ctx, statementID, err != nil) - }() - } - - // ... existing code ... - - // Add telemetry tags during execution - if s.conn.telemetry != nil { - s.conn.telemetry.addTag(ctx, TagResultFormat, "cloudfetch") - s.conn.telemetry.addTag(ctx, TagResultChunkCount, chunkCount) - } - - return rows, nil -} -``` +One `metricsAggregator` + `telemetryExporter` is shared per host (managed by `clientManager` with reference counting). 
This prevents rate limiting when many connections open to the same host.

---

-## 5. Export Mechanism
+## Components

-### 5.1 Data Model
+### `Interceptor` (`interceptor.go`)
+Exported hooks called by the driver package:

-```go
-// telemetryPayload is the JSON structure sent to Databricks.
-type telemetryPayload struct {
-	Metrics []*exportedMetric `json:"metrics"`
-}
-
-// exportedMetric is a single metric in the payload.
-type exportedMetric struct {
-	MetricType  string                 `json:"metric_type"`
-	Timestamp   string                 `json:"timestamp"` // RFC3339
-	WorkspaceID string                 `json:"workspace_id,omitempty"`
-	SessionID   string                 `json:"session_id,omitempty"`
-	StatementID string                 `json:"statement_id,omitempty"`
-	LatencyMs   int64                  `json:"latency_ms,omitempty"`
-	ErrorType   string                 `json:"error_type,omitempty"`
-	Tags        map[string]interface{} `json:"tags,omitempty"`
-}
-
-// toExportedMetric converts internal metric to exported format.
-func (m *telemetryMetric) toExportedMetric() *exportedMetric {
-	// Filter tags based on export scope
-	filteredTags := make(map[string]interface{})
-	for k, v := range m.tags {
-		if shouldExportToDatabricks(m.metricType, k) {
-			filteredTags[k] = v
-		}
-	}

| Method | When called |
|--------|-------------|
| `BeforeExecute(ctx, sessionID, statementID)` | Before statement execution |
| `BeforeExecuteWithTime(ctx, sessionID, statementID, startTime)` | When statement ID is known only after execution starts |
| `AfterExecute(ctx, err)` | After execution completes |
| `AddTag(ctx, key, value)` | During execution to attach metadata |
| `CompleteStatement(ctx, statementID, failed)` | After rows are fully consumed |
| `RecordOperation(ctx, sessionID, opType, latencyMs)` | For non-statement operations |
| `Close(ctx)` | On connection close — **synchronous blocking** flush; waits for the HTTP export to complete before returning, matching JDBC's `flush(true).get()` behavior |

-	return &exportedMetric{
-		MetricType:  m.metricType,
-		Timestamp:   m.timestamp.Format(time.RFC3339),
-		WorkspaceID: m.workspaceID,
-		SessionID:   m.sessionID,
-		StatementID: m.statementID,
-		LatencyMs:   m.latencyMs,
-		ErrorType:   m.errorType,
-		Tags:        filteredTags,
-	}
-}
-```
+### `metricsAggregator` (`aggregator.go`)
+- Aggregates statement metrics by `statementID` (total latency, chunk count, poll count)
+- Flush behavior by metric type:
+  - `"connection"` — flushes immediately (lifecycle event, must not be lost)
+  - `"operation"` — batched; flushes when the batch reaches `BatchSize`, or immediately for terminal operations (session/statement close)
+  - `"statement"` — accumulated until `CompleteStatement`, then batched
+- Background flush ticker runs every `FlushInterval`
+- `flushUnlocked` hands each batch to a fixed pool of export workers (10 workers, 1000-slot queue); a batch is dropped only when the queue is full
+- `flushSync` exports synchronously; used by `Interceptor.Close` to block until delivery
+- `close` is idempotent via `sync.Once`; cancels in-flight periodic exports

-### 5.2 Export Endpoints
+### `telemetryExporter` (`exporter.go`)
+- HTTP POST to `/telemetry-ext`
+- Exponential backoff retries (default: 3 retries, 100ms base delay)
+- Retryable statuses: 429 and all 5xx
+- All errors and panics are swallowed; logged at TRACE level only

-```go
-const (
-	// telemetryEndpointAuth is used when connection has auth token
-	telemetryEndpointAuth = "/api/2.0/telemetry-ext"
+### `circuitBreaker` (`circuitbreaker.go`)
+Sliding window (30 calls), opens when failure rate ≥ 50% over ≥ 20 calls. Recovers after 30s via half-open state (3 successful test calls to close). One circuit breaker per host.
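To make the sliding-window arithmetic concrete, here is a minimal sketch of the failure-rate check (function and parameter names are illustrative, not the actual `circuitbreaker.go` API, which also handles the half-open probe states):

```go
// shouldOpen reports whether the breaker should trip, given the number of
// failed calls among those recorded in the sliding window. With the defaults
// (window 30, minimum 20, threshold 50), 10 failures in 20 recorded calls
// is exactly 50% and opens the circuit.
func shouldOpen(failures, recorded, minimumCalls, thresholdPct int) bool {
	if recorded < minimumCalls {
		return false // too few samples to judge the endpoint
	}
	return failures*100/recorded >= thresholdPct
}
```

Rate-based evaluation tolerates isolated failures that would trip a consecutive-failure counter, while still reacting quickly once the endpoint is mostly failing.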
-	// telemetryEndpointUnauth is used when no auth available
-	telemetryEndpointUnauth = "/api/2.0/telemetry-unauth"
-)
+### `clientManager` (`manager.go`)
+Singleton. Maintains one `(aggregator + exporter)` per host with reference counting. The last connection to close triggers aggregator shutdown and a final flush.

-// selectEndpoint chooses the appropriate endpoint based on auth.
-func (e *telemetryExporter) selectEndpoint() string {
-	if e.cfg.hasAuth {
-		return telemetryEndpointAuth
-	}
-	return telemetryEndpointUnauth
-}
-```
+### `featureFlagCache` (`featureflag.go`)
+Caches the server-side feature flag per host for 15 minutes. The fetch is synchronous: the first connection to a host (cold cache) blocks until it completes. On cache expiry, only the first caller refetches; concurrent callers receive the stale cached value immediately.

---

-## 6. Configuration
-
-### 6.1 Configuration Structure
-
-```go
-package telemetry
-
-import "time"
-
-// Config holds telemetry configuration.
-type Config struct {
-	// Enabled controls whether telemetry is active
-	Enabled bool
-
-	// EnableTelemetry indicates user wants telemetry enabled.
-	// Follows client > server > default priority: if set by the client it takes
-	// precedence; otherwise the server feature flag and defaults are consulted.
-	EnableTelemetry bool
-
-	// BatchSize is the number of metrics to batch before flushing
-	BatchSize int
-
-	// FlushInterval is how often to flush metrics
-	FlushInterval time.Duration
-
-	// MaxRetries is the maximum number of retry attempts
-	MaxRetries int
-
-	// RetryDelay is the base delay between retries
-	RetryDelay time.Duration
-
-	// CircuitBreakerEnabled enables circuit breaker protection
-	CircuitBreakerEnabled bool
-
-	// CircuitBreakerThreshold is failures before opening circuit
-	CircuitBreakerThreshold int
-
-	// CircuitBreakerTimeout is time before retrying after open
-	CircuitBreakerTimeout time.Duration
-}
-
-// DefaultConfig returns default telemetry configuration.
-// Note: Telemetry is disabled by default and requires explicit opt-in.
-func DefaultConfig() *Config {
-	return &Config{
-		Enabled:                 false, // Disabled by default, requires explicit opt-in
-		EnableTelemetry:         false,
-		BatchSize:               100,
-		FlushInterval:           5 * time.Second,
-		MaxRetries:              3,
-		RetryDelay:              100 * time.Millisecond,
-		CircuitBreakerEnabled:   true,
-		CircuitBreakerThreshold: 5,
-		CircuitBreakerTimeout:   1 * time.Minute,
-	}
-}
-```
-
-### 6.2 Configuration from Connection Parameters
-
-```go
-// ParseTelemetryConfig extracts telemetry config from connection parameters.
-func ParseTelemetryConfig(params map[string]string) *Config {
-	cfg := DefaultConfig()
-
-	// Check for enableTelemetry flag (follows client > server > default priority)
-	if v, ok := params["enableTelemetry"]; ok {
-		if v == "true" || v == "1" {
-			cfg.EnableTelemetry = true
-		} else if v == "false" || v == "0" {
-			cfg.EnableTelemetry = false
-		}
-	}
-
-	if v, ok := params["telemetry_batch_size"]; ok {
-		if size, err := strconv.Atoi(v); err == nil && size > 0 {
-			cfg.BatchSize = size
-		}
-	}
-
-	if v, ok := params["telemetry_flush_interval"]; ok {
-		if duration, err := time.ParseDuration(v); err == nil {
-			cfg.FlushInterval = duration
-		}
-	}
-
-	return cfg
-}
-```
-
-### 6.3 Feature Flag Integration
-
-```go
-// featureFlagResponse is the parsed feature flag result.
-type featureFlagResponse struct {
-	Enabled           bool
-	RolloutPercentage int
-}
-
-// checkFeatureFlag checks if telemetry is enabled server-side.
-func checkFeatureFlag(ctx context.Context, host string, httpClient *http.Client) (*featureFlagResponse, error) {
-	endpoint := fmt.Sprintf("https://%s/api/2.0/feature-flags", host)
-
-	req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
-	if err != nil {
-		return nil, err
-	}
-
-	// Add query parameters
-	q := req.URL.Query()
-	q.Add("flags", "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver")
-	req.URL.RawQuery = q.Encode()
-
-	resp, err := httpClient.Do(req)
-	if err != nil {
-		return nil, err
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("feature flag check failed: %d", resp.StatusCode)
-	}
-
-	// Flag values may be booleans or objects, so decode into interface{}.
-	var result struct {
-		Flags map[string]interface{} `json:"flags"`
-	}
-	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
-		return nil, err
-	}
-
-	// Parse flag response
-	flagValue := result.Flags["databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver"]
-
-	response := &featureFlagResponse{
-		Enabled:           false,
-		RolloutPercentage: 0,
-	}
-
-	// Handle both boolean and object responses for backward compatibility
-	switch v := flagValue.(type) {
-	case bool:
-		response.Enabled = v
-		if v {
-			response.RolloutPercentage = 100
-		}
-	case map[string]interface{}:
-		if enabled, ok := v["enabled"].(bool); ok {
-			response.Enabled = enabled
-		}
-		if rollout, ok := v["rollout_percentage"].(float64); ok {
-			response.RolloutPercentage = int(rollout)
-		}
-	}
-
-	return response, nil
-}
-
-// isInRollout checks if this connection is in the rollout percentage.
-// Uses consistent hashing based on workspace ID for stable rollout.
-func isInRollout(workspaceID string, rolloutPercentage int) bool {
-	if rolloutPercentage >= 100 {
-		return true
-	}
-	if rolloutPercentage <= 0 {
-		return false
-	}
-
-	// Use consistent hashing based on workspace ID
-	h := fnv.New32a()
-	h.Write([]byte(workspaceID))
-	hash := h.Sum32()
-
-	return int(hash%100) < rolloutPercentage
-}
-```
-
-#### Synchronous Fetch Behavior
-
-**Feature flag fetching is synchronous** and may block driver initialization.
-
-**Key Characteristics:**
-- 10-second HTTP timeout per request
-- Uses RetryableClient (4 retries, exponential backoff 1s-30s)
-- 15-minute cache minimizes fetch frequency
-- Thundering herd protection (only one fetch per host at a time)
-
-**When It Blocks:**
-- First connection to host: blocks for HTTP fetch (up to ~70s with retries)
-- Cache expiry (every 15 min): first caller blocks, others return stale cache
-- Concurrent callers: only first blocks, others return stale cache immediately
-
-**Why synchronous:** Simple, deterministic, 99% cache hit rate, matches JDBC driver.
-
-### 6.4 Config Overlay Pattern
-
-**UPDATE (Phase 4-5):** The telemetry system now uses a **config overlay pattern** that provides a consistent, clear priority model. This pattern is designed to be reusable across all driver configurations.
-
-#### Config Overlay Priority (highest to lowest):
-
-1. **Client Config** - Explicitly set by user (overrides server)
-2. **Server Config** - Feature flag controls when client doesn't specify
-3. **Fail-Safe Default** - Disabled when server unavailable
-
-This approach eliminates the need for special bypass flags like `forceEnableTelemetry` because client config naturally has priority.
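Before the resolution logic below can run, the client's DSN value has to be parsed into a three-state pointer. A minimal sketch, assuming a plain `map[string]string` of connection parameters (the helper name is hypothetical; the driver's real parsing lives in its config package):

```go
// parseEnableTelemetry maps the DSN parameter onto the three overlay states:
// nil (unset: defer to the server flag), true, or false.
func parseEnableTelemetry(params map[string]string) *bool {
	v, ok := params["enableTelemetry"]
	if !ok {
		return nil // unset: the server feature flag decides
	}
	enabled, err := strconv.ParseBool(v)
	if err != nil {
		// Unparseable values are treated as unset here rather than failing
		// the connection; a sketch-level choice, not confirmed driver behavior.
		return nil
	}
	return &enabled
}
```

The resulting pointer slots directly into `Config.EnableTelemetry` as consumed by the `isTelemetryEnabled` resolution shown below.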
- -#### Implementation: - -```go -// EnableTelemetry is a pointer to distinguish three states: -// - nil: not set by client (use server feature flag) -// - true: client wants enabled (overrides server) -// - false: client wants disabled (overrides server) -type Config struct { - EnableTelemetry *bool - // ... other fields -} - -// isTelemetryEnabled implements config overlay -func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClient *http.Client) bool { - // Priority 1: Client explicitly set (overrides everything) - if cfg.EnableTelemetry != nil { - return *cfg.EnableTelemetry - } +## Export Format - // Priority 2: Check server-side feature flag - flagCache := getFeatureFlagCache() - serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, httpClient) - if err != nil { - // Priority 3: Fail-safe default (disabled) - return false - } +Metrics are sent as a `TelemetryRequest` with a `protoLogs` array. Each entry is a JSON-encoded `TelemetryFrontendLog` aligned with the `OssSqlDriverTelemetryLog` proto schema: - return serverEnabled +```json +{ + "uploadTime": 1234567890000, + "items": [], + "protoLogs": [ + "{\"frontend_log_event_id\":\"20240101120000-a1b2c3d4e5f6g7h8\",\"context\":{...},\"entry\":{\"sql_driver_log\":{\"session_id\":\"...\",\"sql_statement_id\":\"...\",\"operation_latency_ms\":42,\"sql_operation\":{\"chunk_details\":{\"total_chunks_iterated\":3},...}}}}" + ] } ``` -#### Configuration Behavior Matrix: +Event IDs are generated with `crypto/rand` + hex encoding (no modulo bias). -| Client Sets | Server Returns | Result | Explanation | -|-------------|----------------|--------|-------------| -| `true` | `false` | **`true`** | Client overrides server | -| `false` | `true` | **`false`** | Client overrides server | -| `true` | error | **`true`** | Client overrides server error | -| unset | `true` | **`true`** | Use server config | -| unset | `false` | **`false`** | Use server config | -| unset | error | **`false`** | Fail-safe default | - -#### Configuration Parameter Summary: - -| Parameter | Value | Behavior | Use Case | -|-----------|-------|----------|----------| -| `enableTelemetry=true` | Client forces enabled | Always send telemetry (overrides server) | Testing, debugging, opt-in users | -| `enableTelemetry=false` | Client forces disabled | Never send telemetry (overrides server) | Privacy-conscious users, opt-out | -| *(not set)* | Use server flag | Server controls via feature flag | Default behavior - Databricks-controlled rollout | - -#### Benefits of Config Overlay: - -- ✅ **Simpler**: Client > Server > Default (3 clear layers) -- ✅ **Consistent**: Same pattern can be used for all driver configs -- ✅ **No bypass flags**: Client config naturally has priority -- ✅ **Reusable**: General `ConfigValue[T]` type in `internal/config/overlay.go` -- ✅ **Type-safe**: Uses Go generics for any config type - -#### General Config Overlay System: - -A reusable config overlay system is available in `internal/config/overlay.go`: - -```go -// Generic config value that supports overlay pattern -type ConfigValue[T any] struct { - value *T // nil = unset, non-nil = client set -} - -// Parse from connection params -cv := ParseBoolConfigValue(params, "enableFeature") - -// Resolve with overlay priority -result := cv.ResolveWithContext(ctx, host, httpClient, serverResolver, defaultValue) -``` - -**Note**: A general `ConfigValue[T]` implementation is available in `internal/config/overlay.go` for extending this pattern to other driver configurations. - ---- - -## 7. 
Privacy & Compliance - -### 7.1 Data Privacy - -**Never Collected**: -- ❌ SQL query text -- ❌ Query results or data values -- ❌ Table/column names -- ❌ User identities - -**Always Collected**: -- ✅ Operation latency -- ✅ Error codes (not messages) -- ✅ Feature flags (boolean) -- ✅ Statement IDs (UUIDs) -- ✅ Driver version and runtime info - -### 7.2 Tag Filtering - -```go -// filterTagsForExport removes tags that shouldn't be sent to Databricks. -func filterTagsForExport(metricType string, tags map[string]interface{}) map[string]interface{} { - filtered := make(map[string]interface{}) - - for key, value := range tags { - if shouldExportToDatabricks(metricType, key) { - filtered[key] = value - } - // Tags not in allowlist are silently dropped - } - - return filtered -} -``` - ---- - -## 8. Error Handling - -### 8.1 Error Swallowing Strategy - -**Core Principle**: Every telemetry error must be swallowed with minimal logging. - -```go -// recoverAndLog recovers from panics and logs at trace level. -func recoverAndLog(operation string) { - if r := recover(); r != nil { - // Use TRACE level logging (not exposed to users by default) - // logger.Trace().Msgf("telemetry: %s panic: %v", operation, r) - } -} - -// Example usage -func (i *interceptor) afterExecute(ctx context.Context, err error) { - defer recoverAndLog("afterExecute") - - // Telemetry logic that might panic - // ... -} -``` - -### 8.2 Error Classification - -```go -// errorClassifier classifies errors as terminal or retryable. -type errorClassifier struct{} - -// isTerminalError returns true if error is terminal (non-retryable). -func isTerminalError(err error) bool { - if err == nil { - return false - } - - // Unwrap error to check underlying type - var httpErr *httpError - if errors.As(err, &httpErr) { - return isTerminalHTTPStatus(httpErr.statusCode) - } - - // Check error message patterns - errMsg := err.Error() - terminalPatterns := []string{ - "authentication failed", - "unauthorized", - "forbidden", - "not found", - "invalid request", - "syntax error", - } - - for _, pattern := range terminalPatterns { - if strings.Contains(strings.ToLower(errMsg), pattern) { - return true - } - } - - return false -} - -// isTerminalHTTPStatus returns true for non-retryable HTTP status codes. -func isTerminalHTTPStatus(status int) bool { - return status == 400 || status == 401 || status == 403 || status == 404 -} -``` - ---- - -## 9. Graceful Shutdown - -### 9.1 Shutdown Sequence - -```go -// In connection.go -func (c *conn) Close() error { - // Close telemetry before closing connection - if c.telemetry != nil { - // This is non-blocking and swallows errors - c.telemetry.close(context.Background()) - } - - // Release per-host resources - if c.cfg.telemetryEnabled { - getClientManager().releaseClient(c.cfg.Host) - getFeatureFlagCache().releaseContext(c.cfg.Host) - } - - // ... existing connection close logic ... - - return nil -} -``` - -### 9.2 Client Manager Shutdown - -The `clientManager` now includes a `shutdown()` method that provides graceful cleanup of all telemetry clients on application shutdown. This method: - -- Closes all active telemetry clients regardless of reference counts -- Logs warnings for any close failures -- Clears the clients map to prevent memory leaks -- Returns the last error encountered (if any) - -```go -// shutdown closes all telemetry clients and clears the manager. -// Integration points will be determined in Phase 4. 
-func (m *clientManager) shutdown() error { - m.mu.Lock() - defer m.mu.Unlock() - - var lastErr error - for host, holder := range m.clients { - if err := holder.client.close(); err != nil { - logger.Logger.Warn().Str("host", host).Err(err).Msg("error closing telemetry client during shutdown") - lastErr = err - } - } - // Clear the map - m.clients = make(map[string]*clientHolder) - return lastErr -} -``` - -**Integration Options** (to be implemented in Phase 4): - -1. **Public API**: Export a `Shutdown()` function for applications to call during their shutdown sequence -2. **Driver Hook**: Integrate with `sql.DB.Close()` or driver cleanup mechanisms -3. **Signal Handler**: Call from application signal handlers (SIGTERM, SIGINT) - -### 9.3 Client Shutdown - -```go -// close shuts down the telemetry client gracefully. -func (c *telemetryClient) close() error { - defer recoverAndLog("client.close") - - // Stop background flush - close(c.stopCh) - - // Flush pending metrics with timeout - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - c.aggregator.flush(ctx) - - return nil -} -``` +System info (OS name/version/arch, Go runtime version) is read once and cached via `sync.Once`. --- -## 10. Testing Strategy - -### 10.1 Unit Tests - -```go -// telemetry/interceptor_test.go -func TestInterceptor_BeforeAfterExecute(t *testing.T) { - agg := &mockAggregator{} - interceptor := &interceptor{ - aggregator: agg, - enabled: true, - } - - ctx := context.Background() - ctx = interceptor.beforeExecute(ctx, "stmt-123") - - // Verify metric context is attached - mc := getMetricContext(ctx) - assert.NotNil(t, mc) - assert.Equal(t, "stmt-123", mc.statementID) - - // Simulate execution - time.Sleep(10 * time.Millisecond) - - interceptor.afterExecute(ctx, nil) +## Configuration - // Verify metric was recorded - assert.Equal(t, 1, len(agg.metrics)) - assert.True(t, agg.metrics[0].latencyMs >= 10) -} - -func TestInterceptor_ErrorHandling(t *testing.T) { - agg := &mockAggregator{ - shouldPanic: true, - } - interceptor := &interceptor{ - aggregator: agg, - enabled: true, - } - - ctx := context.Background() - ctx = interceptor.beforeExecute(ctx, "stmt-123") - - // Should not panic even if aggregator panics - assert.NotPanics(t, func() { - interceptor.afterExecute(ctx, nil) - }) -} -``` - -### 10.2 Integration Tests - -```go -// telemetry/integration_test.go -func TestTelemetry_EndToEnd(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test") - } - - // Create mock server - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/api/2.0/telemetry-ext" { - w.WriteHeader(http.StatusOK) - } - })) - defer server.Close() - - // Create connection with telemetry - cfg := &config.Config{ - Host: server.URL, - TelemetryEnabled: true, - } - - conn, err := connect(cfg) - require.NoError(t, err) - defer conn.Close() - - // Execute statement - _, err = conn.QueryContext(context.Background(), "SELECT 1") - require.NoError(t, err) - - // Wait for flush - time.Sleep(6 * time.Second) - - // Verify metrics were sent to server - // (check server mock received requests) -} -``` +DSN parameter: `enableTelemetry=true|false` (parsed via `strconv.ParseBool`) -### 10.3 Benchmark Tests +Priority (highest to lowest): +1. Client DSN setting (`enableTelemetry=true/false`) — overrides server +2. Server feature flag (`databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver`) +3. 
Default: disabled

-```go
-// telemetry/benchmark_test.go
-func BenchmarkInterceptor_Overhead(b *testing.B) {
-	agg := &metricsAggregator{
-		// ... initialized ...
-	}
-	interceptor := &interceptor{
-		aggregator: agg,
-		enabled:    true,
-	}
-
-	ctx := context.Background()
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		// Derive from the base context each iteration; reassigning ctx would
-		// nest metric contexts b.N levels deep and skew the measurement.
-		c := interceptor.beforeExecute(ctx, "stmt-123")
-		interceptor.afterExecute(c, nil)
-	}
-}
-
-func BenchmarkInterceptor_Disabled(b *testing.B) {
-	agg := &metricsAggregator{
-		// ... initialized ...
-	}
-	interceptor := &interceptor{
-		aggregator: agg,
-		enabled:    false, // Disabled
-	}
-
-	ctx := context.Background()
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		c := interceptor.beforeExecute(ctx, "stmt-123")
-		interceptor.afterExecute(c, nil)
-	}
-}
-```
+Batch size, flush interval, retry count, and retry delay are also configurable.

---

-## 11. Implementation Checklist
-
-**Strategy**: Build infrastructure bottom-up: Circuit Breaker → Export (POST to endpoint) → Opt-In Configuration → Collection & Aggregation → Driver Integration. This allows unit testing each layer before adding metric collection.
-
-**JIRA Tickets**:
-- **PECOBLR-1143**: Phases 1-5 (Core Infrastructure → Opt-In Configuration)
-  - **PECOBLR-1381**: Phase 6 (Collection & Aggregation) - subtask
-  - **PECOBLR-1382**: Phase 7 (Driver Integration) - subtask
-
-### Phase 1: Core Infrastructure ✅ COMPLETED
-- [x] Create `telemetry` package structure
-- [x] Implement `config.go` with configuration types (basic structure)
-- [x] Implement `tags.go` with tag definitions and filtering
-- [x] Add unit tests for configuration and tags
-
-### Phase 2: Per-Host Management ✅ COMPLETED
-- [x] Implement `featureflag.go` with caching and reference counting
-- [x] Implement `manager.go` for client management
-  - [x] Thread-safe singleton pattern with per-host client holders
-  - [x] Reference counting for automatic cleanup
-  - [x] Error handling for client start failures
-  - [x] Shutdown method for graceful application shutdown
-  - [x] Comprehensive documentation on thread-safety and connection sharing
-- [x] Implement `client.go` with minimal telemetryClient stub
-  - [x] Thread-safe start() and close() methods
-  - [x] Mutex protection for state flags
-  - [x] Detailed documentation on concurrent access requirements
-- [x] Add comprehensive unit tests for all components
-  - [x] Singleton pattern verification
-  - [x] Reference counting (increment/decrement/cleanup)
-  - [x] Concurrent access tests (100+ goroutines)
-  - [x] Shutdown scenarios (empty, with active refs, multiple hosts)
-  - [x] Race detector tests passing
-
-### Phase 3: Circuit Breaker ✅ COMPLETED
-- [x] Implement `circuitbreaker.go` with state machine
-  - [x] Implement circuit breaker states (Closed, Open, Half-Open)
-  - [x] Implement circuitBreakerManager singleton per host
-  - [x] Add configurable thresholds and timeout
-  - [x] Implement execute() method with state transitions
-  - [x] Implement failure/success tracking with sliding window algorithm
-- [x] Add comprehensive unit tests
-  - [x] Test state transitions (Closed → Open → Half-Open → Closed)
-  - [x] Test failure/success counting
-  - [x] Test timeout and retry logic
-  - [x] Test per-host circuit breaker isolation
-  - [x] Test concurrent access
-
-### Phase 4: Export Infrastructure ✅ COMPLETED
-- [x] Implement `exporter.go` with retry logic
-  - [x] Implement HTTP POST to telemetry endpoint (/api/2.0/telemetry-ext)
-  - [x] Implement retry logic with exponential backoff
-  - [x] Implement tag
filtering for export (shouldExportToDatabricks) - - [x] Integrate with circuit breaker - - [x] Add error swallowing - - [x] Implement toExportedMetric() conversion - - [x] Implement telemetryPayload JSON structure -- [x] Add unit tests for export logic - - [x] Test HTTP request construction - - [x] Test retry logic (with mock HTTP responses) - - [x] Test circuit breaker integration - - [x] Test tag filtering - - [x] Test error swallowing -- [x] Add integration tests with mock HTTP server - - [x] Test successful export - - [x] Test error scenarios (4xx, 5xx) - - [x] Test retry behavior (exponential backoff) - - [x] Test circuit breaker opening/closing - - [x] Test context cancellation - -### Phase 5: Opt-In Configuration Integration ✅ COMPLETED -- [x] Implement `isTelemetryEnabled()` with priority-based logic in config.go - - [x] Priority 1 (client): EnableTelemetry=true → enable regardless of server flag - - [x] Priority 2 (client): EnableTelemetry=false → disable regardless of server flag - - [x] Priority 3 (server): Server feature flag controls when client preference unset - - [x] Priority 4 (default): Disabled if no flags set and server check fails -- [x] Integrate feature flag cache with opt-in logic - - [x] Wire up isTelemetryEnabled() to call featureFlagCache.isTelemetryEnabled() - - [x] Implement fallback behavior on errors (return cached value or false) - - [x] Add proper error handling -- [x] Add unit tests for opt-in priority logic - - [x] Test enableTelemetry=false (always disabled, explicit opt-out) - - [x] Test enableTelemetry=true with server flag enabled - - [x] Test enableTelemetry=true with server flag disabled - - [x] Test default behavior (server flag controls) - - [x] Test error scenarios (server unreachable, use cached value) -- [x] Add integration tests with mock feature flag server - - [x] Test opt-in priority with mock server - - [x] Test server error handling - - [x] Test unreachable server scenarios - -### Phase 6: Collection & Aggregation (PECOBLR-1381) ✅ COMPLETED -- [x] Implement `interceptor.go` for metric collection - - [x] Implement beforeExecute() and afterExecute() hooks - - [x] Implement context-based metric tracking with metricContext - - [x] Implement latency measurement (startTime, latencyMs calculation) - - [x] Add tag collection methods (addTag) - - [x] Implement error swallowing with panic recovery -- [x] Implement `aggregator.go` for batching - - [x] Implement statement-level aggregation (statementMetrics) - - [x] Implement batch size and flush interval logic - - [x] Implement background flush goroutine (flushLoop) - - [x] Add thread-safe metric recording - - [x] Implement completeStatement() for final aggregation -- [x] Implement error classification in `errors.go` - - [x] Implement error type classification (terminal vs retryable) - - [x] Implement HTTP status code classification - - [x] Add error pattern matching - - [x] Implement isTerminalError() function -- [x] Update `client.go` to integrate aggregator - - [x] Wire up aggregator with exporter - - [x] Implement background flush timer - - [x] Update start() and close() methods -- [x] Add unit tests for collection and aggregation - - [x] Test interceptor metric collection and latency tracking - - [x] Test aggregation logic - - [x] Test batch flushing (size-based and time-based) - - [x] Test error classification - - [x] Test client with aggregator integration - -### Phase 7: Driver Integration ✅ COMPLETED -- [x] Add telemetry initialization to `connection.go` - - [x] Call isTelemetryEnabled() at 
connection open via InitializeForConnection() - - [x] Initialize telemetry client via clientManager.getOrCreateClient() - - [x] Increment feature flag cache reference count - - [x] Store telemetry interceptor in connection -- [x] Add telemetry configuration to UserConfig - - [x] EnableTelemetry field (client > server > default priority) - - [x] DSN parameter parsing - - [x] DeepCopy support -- [x] Add cleanup in `Close()` methods - - [x] Release client manager reference in connection.Close() - - [x] Release feature flag cache reference via ReleaseForConnection() - - [x] Flush pending metrics before close -- [x] Export necessary types and methods - - [x] Export Interceptor type - - [x] Export GetInterceptor() and Close() methods - - [x] Create driver integration helpers -- [x] Basic integration tests - - [x] Test compilation with telemetry - - [x] Test no breaking changes to existing tests - - [x] Test graceful handling when disabled -- [x] Statement execution hooks - - [x] Add beforeExecute() hook to QueryContext - - [x] Add afterExecute() and completeStatement() hooks to QueryContext - - [x] Add beforeExecute() hook to ExecContext - - [x] Add afterExecute() and completeStatement() hooks to ExecContext - - [x] Use operation handle GUID as statement ID - -### Phase 8: Testing & Validation -- [ ] Run benchmark tests - - [ ] Measure overhead when enabled - - [ ] Measure overhead when disabled - - [ ] Ensure <1% overhead when enabled -- [ ] Perform load testing with concurrent connections - - [ ] Test 100+ concurrent connections - - [ ] Verify per-host client sharing - - [ ] Verify no rate limiting with per-host clients -- [ ] Validate graceful shutdown - - [ ] Test reference counting cleanup - - [ ] Test final flush on shutdown - - [ ] Test shutdown method works correctly -- [ ] Test circuit breaker behavior - - [ ] Test circuit opening on repeated failures - - [ ] Test circuit recovery after timeout - - [ ] Test metrics dropped when circuit open -- [ ] Test opt-in priority logic end-to-end - - [ ] Verify forceEnableTelemetry works in real driver - - [ ] Verify enableTelemetry works in real driver - - [ ] Verify server flag integration works -- [ ] Verify privacy compliance - - [ ] Verify no SQL queries collected - - [ ] Verify no PII collected - - [ ] Verify tag filtering works (shouldExportToDatabricks) - -### Phase 9: Partial Launch Preparation -- [ ] Document `forceEnableTelemetry` and `enableTelemetry` flags -- [ ] Create internal testing plan for Phase 1 (use forceEnableTelemetry=true) -- [ ] Prepare beta opt-in documentation for Phase 2 (use enableTelemetry=true) -- [ ] Set up monitoring for rollout health metrics -- [ ] Document rollback procedures (set server flag to false) - -### Phase 10: Documentation -- [ ] Document configuration options in README -- [ ] Add examples for opt-in flags -- [ ] Document partial launch strategy and phases -- [ ] Document metric tags and their meanings -- [ ] Create troubleshooting guide -- [ ] Document architecture and design decisions - ---- - -## 12. 
References - -### 12.1 Go Standards -- [context package](https://pkg.go.dev/context) -- [database/sql/driver](https://pkg.go.dev/database/sql/driver) -- [net/http](https://pkg.go.dev/net/http) -- [sync package](https://pkg.go.dev/sync) -- [Effective Go](https://go.dev/doc/effective_go) - -### 12.2 Existing Code References - -**Databricks SQL Go Driver**: -- `connection.go`: Connection management -- `connector.go`: Connection factory -- `internal/config/config.go`: Configuration patterns -- `internal/client/client.go`: HTTP client patterns - -**JDBC Driver** (reference implementation): -- `TelemetryClient.java:15`: Batching and flush logic -- `TelemetryClientFactory.java:27`: Per-host client management -- `CircuitBreakerManager.java:25`: Circuit breaker pattern -- `DatabricksDriverFeatureFlagsContextFactory.java:27`: Feature flag caching - ---- - -## Summary - -This **telemetry design for Go** follows idiomatic Go patterns: +## Privacy -1. **Idiomatic Go**: Uses context, interfaces, goroutines, and channels -2. **Minimal changes**: Wraps existing driver operations with interceptors -3. **Type safety**: Strong typing with exported/unexported types -4. **Error handling**: All errors swallowed, uses defer/recover patterns -5. **Concurrency**: Thread-safe using sync primitives (RWMutex, atomic) -6. **Testing**: Comprehensive unit, integration, and benchmark tests -7. **Standard library**: Minimal external dependencies +Only the following fields are exported: +- Session ID, statement ID, workspace ID +- Operation latency, error type (no message/stack trace) +- Chunk count, poll count +- Driver version, OS name, Go runtime version, result format -The design enables collecting valuable usage metrics while maintaining Go best practices and ensuring zero impact on driver reliability and performance. +Tags not in this allowlist are silently dropped before export. diff --git a/telemetry/aggregator.go b/telemetry/aggregator.go index 15323f08..38dd636c 100644 --- a/telemetry/aggregator.go +++ b/telemetry/aggregator.go @@ -8,6 +8,17 @@ import ( "github.com/databricks/databricks-sql-go/logger" ) +const ( + exportWorkerCount = 10 // Fixed worker pool size — matches JDBC's newFixedThreadPool(10) + exportQueueSize = 1000 // Buffered queue — effectively unbounded for normal workloads +) + +// exportJob holds a batch of metrics pending async export. +type exportJob struct { + ctx context.Context + metrics []*telemetryMetric +} + // metricsAggregator aggregates metrics by statement and batches for export. type metricsAggregator struct { mu sync.RWMutex @@ -21,10 +32,10 @@ type metricsAggregator struct { stopCh chan struct{} flushTimer *time.Ticker - closeOnce sync.Once - ctx context.Context // Cancellable context for in-flight exports - cancel context.CancelFunc // Cancels ctx on close - exportSem chan struct{} // Bounds concurrent export goroutines + closeOnce sync.Once + ctx context.Context // Cancellable context — cancelled on close to stop workers + cancel context.CancelFunc + exportQueue chan exportJob // Worker queue; drop batch only when full (matches JDBC LinkedBlockingQueue) } // statementMetrics holds aggregated metrics for a statement. @@ -41,7 +52,7 @@ type statementMetrics struct { // newMetricsAggregator creates a new metrics aggregator. 
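
// ---------------------------------------------------------------------------
// Aside (illustrative, not part of the patch): the bounded-queue worker pool
// introduced above (exportWorkerCount, exportQueueSize, exportJob), reduced
// to a self-contained sketch. All names here are invented for the example;
// only the submit-or-drop channel pattern mirrors the real exportWorker and
// flushUnlocked code that follows.
package main

import (
	"context"
	"fmt"
	"sync"
)

type job struct{ id int }

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	queue := make(chan job, 4) // bounded buffer, like exportQueueSize

	var wg sync.WaitGroup
	for i := 0; i < 2; i++ { // fixed pool, like exportWorkerCount
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				select {
				case j, ok := <-queue:
					if !ok {
						return
					}
					fmt.Println("exported", j.id) // stand-in for exporter.export
				case <-ctx.Done():
					return
				}
			}
		}()
	}

	for i := 0; i < 10; i++ {
		select {
		case queue <- job{id: i}: // handed to a worker
		default:
			fmt.Println("dropped", i) // queue full: drop, never block the caller
		}
	}

	cancel()
	wg.Wait()
}
// ---------------------------------------------------------------------------
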
func newMetricsAggregator(exporter *telemetryExporter, cfg *Config) *metricsAggregator { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) //nolint:gosec // cancel stored in agg.cancel and called on close agg := &metricsAggregator{ statements: make(map[string]*statementMetrics), batch: make([]*telemetryMetric, 0, cfg.BatchSize), @@ -51,7 +62,12 @@ func newMetricsAggregator(exporter *telemetryExporter, cfg *Config) *metricsAggr stopCh: make(chan struct{}), ctx: ctx, cancel: cancel, - exportSem: make(chan struct{}, 8), // Bound to 8 concurrent exports + exportQueue: make(chan exportJob, exportQueueSize), + } + + // Start fixed worker pool — matches JDBC's newFixedThreadPool(10) + for i := 0; i < exportWorkerCount; i++ { + go agg.exportWorker() } // Start background flush timer @@ -60,12 +76,27 @@ func newMetricsAggregator(exporter *telemetryExporter, cfg *Config) *metricsAggr return agg } +// exportWorker processes export jobs from the queue until the aggregator is cancelled. +func (agg *metricsAggregator) exportWorker() { + for { + select { + case job, ok := <-agg.exportQueue: + if !ok { + return + } + agg.exporter.export(job.ctx, job.metrics) + case <-agg.ctx.Done(): + return + } + } +} + // recordMetric records a metric for aggregation. func (agg *metricsAggregator) recordMetric(ctx context.Context, metric *telemetryMetric) { // Swallow all errors defer func() { if r := recover(); r != nil { - logger.Debug().Msgf("telemetry: recordMetric panic: %v", r) + logger.Trace().Msgf("telemetry: recordMetric panic: %v", r) } }() @@ -74,10 +105,18 @@ func (agg *metricsAggregator) recordMetric(ctx context.Context, metric *telemetr switch metric.metricType { case "connection": - // Emit connection events immediately: connection lifecycle events must be captured - // before the connection closes, as we won't have another opportunity to flush + // Connection events flush immediately: lifecycle events must be captured before + // the connection closes, as we won't have another opportunity to flush. agg.batch = append(agg.batch, metric) - if len(agg.batch) >= agg.batchSize { + agg.flushUnlocked(ctx) + + case "operation": + agg.batch = append(agg.batch, metric) + // Terminal operations (session/statement close) flush immediately so metrics + // are not lost if the connection closes before the next batch flush — matching + // JDBC behavior where CLOSE_STATEMENT and DELETE_SESSION trigger immediate export. + opType, _ := metric.tags["operation_type"].(string) + if isTerminalOperationType(opType) || len(agg.batch) >= agg.batchSize { agg.flushUnlocked(ctx) } @@ -118,8 +157,7 @@ func (agg *metricsAggregator) recordMetric(ctx context.Context, metric *telemetr case "error": // Check if terminal error if metric.errorType != "" && isTerminalError(&simpleError{msg: metric.errorType}) { - // Flush terminal errors immediately: terminal errors often lead to connection - // termination. 
If we wait for the next batch/timer flush, this data may be lost + // Flush terminal errors immediately agg.batch = append(agg.batch, metric) agg.flushUnlocked(ctx) } else { @@ -135,7 +173,7 @@ func (agg *metricsAggregator) recordMetric(ctx context.Context, metric *telemetr func (agg *metricsAggregator) completeStatement(ctx context.Context, statementID string, failed bool) { defer func() { if r := recover(); r != nil { - logger.Debug().Msgf("telemetry: completeStatement panic: %v", r) + logger.Trace().Msgf("telemetry: completeStatement panic: %v", r) } }() @@ -192,52 +230,63 @@ func (agg *metricsAggregator) flushLoop() { } } -// flush flushes pending metrics to exporter. +// flush flushes pending metrics to exporter asynchronously (fire-and-forget). func (agg *metricsAggregator) flush(ctx context.Context) { agg.mu.Lock() defer agg.mu.Unlock() agg.flushUnlocked(ctx) } -// flushUnlocked flushes without locking (caller must hold lock). +// flushSync flushes pending metrics synchronously, blocking until the export +// completes. Used on connection close to guarantee delivery before returning. +func (agg *metricsAggregator) flushSync(ctx context.Context) { + agg.mu.Lock() + if len(agg.batch) == 0 { + agg.mu.Unlock() + return + } + metrics := make([]*telemetryMetric, len(agg.batch)) + copy(metrics, agg.batch) + agg.batch = agg.batch[:0] + agg.mu.Unlock() + + agg.exporter.export(ctx, metrics) +} + +// flushUnlocked submits the current batch to the export worker queue (caller must hold lock). +// Matches JDBC's executorService.submit() — drops only when the queue is full, not merely +// when workers are busy. Workers drain the queue until the aggregator context is cancelled. func (agg *metricsAggregator) flushUnlocked(ctx context.Context) { if len(agg.batch) == 0 { return } - // Copy batch and clear metrics := make([]*telemetryMetric, len(agg.batch)) copy(metrics, agg.batch) agg.batch = agg.batch[:0] - // Acquire semaphore slot; skip export if already at capacity to prevent goroutine leaks select { - case agg.exportSem <- struct{}{}: + case agg.exportQueue <- exportJob{ctx: ctx, metrics: metrics}: default: - logger.Debug().Msg("telemetry: export semaphore full, dropping metrics batch") - return + // Queue full — drop batch silently (matches JDBC's RejectedExecutionException path) + logger.Debug().Msg("telemetry: export queue full, dropping metrics batch") } - - // Export asynchronously - go func() { - defer func() { - <-agg.exportSem - if r := recover(); r != nil { - logger.Debug().Msgf("telemetry: async export panic: %v", r) - } - }() - agg.exporter.export(ctx, metrics) - }() } // close stops the aggregator and flushes pending metrics. -// Safe to call multiple times — subsequent calls are no-ops for the stop/cancel step. +// Safe to call multiple times — subsequent calls are no-ops (closeOnce). +// +// Shutdown order matters: +// 1. Stop periodic flush (close stopCh) so no new async exports are queued. +// 2. Synchronously flush the current batch directly (flushSync bypasses the +// worker queue, so it works even after workers are stopped). +// 3. Cancel the aggregator context to stop the 10 export worker goroutines. 
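
// ---------------------------------------------------------------------------
// Aside (illustrative): the three-step close ordering described above, as a
// minimal self-contained sketch. Types and names are invented; only the
// sequencing (stop the timer, flush synchronously, cancel the workers)
// mirrors the real close() below.
package main

import (
	"context"
	"fmt"
	"sync"
)

type agg struct {
	once   sync.Once
	stopCh chan struct{}
	cancel context.CancelFunc
}

func (a *agg) close() {
	a.once.Do(func() {
		close(a.stopCh)            // 1. periodic flush loop exits
		fmt.Println("final flush") // 2. synchronous flush; bypasses the worker queue
		a.cancel()                 // 3. export workers exit via ctx.Done()
	})
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	a := &agg{stopCh: make(chan struct{}), cancel: cancel}
	a.close()
	a.close() // safe: sync.Once makes repeat calls no-ops
	<-ctx.Done()
	fmt.Println("workers stopped")
}
// ---------------------------------------------------------------------------
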
func (agg *metricsAggregator) close(ctx context.Context) error { agg.closeOnce.Do(func() { - close(agg.stopCh) - agg.cancel() // Cancel in-flight periodic export goroutines + close(agg.stopCh) // Stop periodic flush loop + agg.flushSync(ctx) // Final flush — direct export, no workers needed + agg.cancel() // Stop export workers after final flush }) - agg.flush(ctx) return nil } diff --git a/telemetry/benchmark_test.go b/telemetry/benchmark_test.go new file mode 100644 index 00000000..d4687a08 --- /dev/null +++ b/telemetry/benchmark_test.go @@ -0,0 +1,315 @@ +// This file contains performance benchmarks and load tests for the telemetry package. +// Run manually — NOT part of CI: +// +// go test -bench=. -benchmem -run='^$' ./telemetry/... +// +// Benchmarks measure telemetry overhead on the hot path and validate that telemetry +// adds negligible latency to normal driver operations. +package telemetry + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" +) + +// BenchmarkInterceptor_Overhead_Enabled measures telemetry overhead on the hot path +// when telemetry is enabled. This is the worst-case overhead added to every statement. +func BenchmarkInterceptor_Overhead_Enabled(b *testing.B) { + cfg := DefaultConfig() + cfg.BatchSize = 1000 + cfg.FlushInterval = 10 * time.Minute // suppress periodic flush during bench + + exporter := newTelemetryExporter("localhost", "test-version", &http.Client{}, cfg) + agg := newMetricsAggregator(exporter, cfg) + defer agg.close(context.Background()) //nolint:errcheck + + interceptor := newInterceptor(agg, true) + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + ctx2 := interceptor.BeforeExecute(ctx, "session-bench", fmt.Sprintf("stmt-%d", i)) + interceptor.AfterExecute(ctx2, nil) + } +} + +// BenchmarkInterceptor_Overhead_Disabled measures the disabled-path overhead (baseline). +// The delta between Enabled and Disabled is the pure telemetry cost. +func BenchmarkInterceptor_Overhead_Disabled(b *testing.B) { + cfg := DefaultConfig() + exporter := newTelemetryExporter("localhost", "test-version", &http.Client{}, cfg) + agg := newMetricsAggregator(exporter, cfg) + defer agg.close(context.Background()) //nolint:errcheck + + interceptor := newInterceptor(agg, false) + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + ctx2 := interceptor.BeforeExecute(ctx, "session-bench", fmt.Sprintf("stmt-%d", i)) + interceptor.AfterExecute(ctx2, nil) + } +} + +// BenchmarkAggregator_RecordMetric measures raw metric recording throughput under +// concurrent writes. This is the critical path for every driver operation. +func BenchmarkAggregator_RecordMetric(b *testing.B) { + cfg := DefaultConfig() + cfg.BatchSize = 10000 + cfg.FlushInterval = 10 * time.Minute + + exporter := newTelemetryExporter("localhost", "test-version", &http.Client{}, cfg) + agg := newMetricsAggregator(exporter, cfg) + defer agg.close(context.Background()) //nolint:errcheck + + ctx := context.Background() + metric := &telemetryMetric{ + metricType: "operation", + timestamp: time.Now(), + sessionID: "bench-session", + statementID: "bench-stmt", + latencyMs: 10, + tags: map[string]interface{}{"operation_type": OperationTypeExecuteStatement}, + } + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + agg.recordMetric(ctx, metric) + } +} + +// BenchmarkExporter_Export measures export throughput against a real HTTP server. 
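
// ---------------------------------------------------------------------------
// Aside (illustrative): one way to turn the Enabled/Disabled pair of
// benchmarks above into a single overhead number, using the standard
// library's testing.Benchmark. The loop bodies are stand-ins; in practice
// they would call the interceptor hooks exactly as the benchmarks above do.
package main

import (
	"fmt"
	"testing"
)

func main() {
	enabled := testing.Benchmark(func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			_ = i // telemetry-on path goes here
		}
	})
	disabled := testing.Benchmark(func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			_ = i // telemetry-off baseline goes here
		}
	})
	// The delta is the per-statement telemetry cost to compare against the
	// <1% overhead target.
	fmt.Printf("overhead: %d ns/op\n", enabled.NsPerOp()-disabled.NsPerOp())
}
// ---------------------------------------------------------------------------
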
+func BenchmarkExporter_Export(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + cfg := DefaultConfig() + cfg.MaxRetries = 0 + httpClient := &http.Client{Timeout: 5 * time.Second} + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + + metrics := make([]*telemetryMetric, 10) + for i := range metrics { + metrics[i] = &telemetryMetric{ + metricType: "operation", + timestamp: time.Now(), + sessionID: "bench-session", + latencyMs: int64(i), + tags: map[string]interface{}{}, + } + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + exporter.export(ctx, metrics) + } +} + +// BenchmarkConcurrentConnections_PerHostSharing measures per-host client sharing +// under concurrent access. Multiple goroutines contend for the same host entry, +// simulating many connections to one warehouse. +func BenchmarkConcurrentConnections_PerHostSharing(b *testing.B) { + manager := &clientManager{ + clients: make(map[string]*clientHolder), + } + + host := "bench-host.databricks.com" + httpClient := &http.Client{} + cfg := DefaultConfig() + cfg.FlushInterval = 10 * time.Minute + + b.ResetTimer() + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + client := manager.getOrCreateClient(host, "test-version", httpClient, cfg) + if client != nil { + _ = manager.releaseClient(host) + } + } + }) +} + +// BenchmarkCircuitBreaker_Execute measures the circuit breaker hot-path overhead +// with the circuit closed (normal operation). +func BenchmarkCircuitBreaker_Execute(b *testing.B) { + cb := newCircuitBreaker(defaultCircuitBreakerConfig()) + ctx := context.Background() + fn := func() error { return nil } + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = cb.execute(ctx, fn) + } +} + +// TestLoadTesting_ConcurrentConnections verifies correctness of the aggregator +// and client manager under high concurrency: 100 goroutines, 50 ops each. 
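
// ---------------------------------------------------------------------------
// Aside (illustrative): the shape of b.RunParallel used by the per-host
// sharing benchmark above. RunParallel splits b.N iterations across
// GOMAXPROCS goroutines; each pb.Next() claims one iteration, so the loop
// body is exactly where lock contention is exercised.
package main

import (
	"fmt"
	"testing"
)

func main() {
	res := testing.Benchmark(func(b *testing.B) {
		b.RunParallel(func(pb *testing.PB) {
			for pb.Next() {
				_ = 1 + 1 // a contended getOrCreateClient/releaseClient pair goes here
			}
		})
	})
	fmt.Println(res)
}
// ---------------------------------------------------------------------------
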
+func TestLoadTesting_ConcurrentConnections(t *testing.T) { + if testing.Short() { + t.Skip("skipping load test in short mode") + } + + manager := &clientManager{ + clients: make(map[string]*clientHolder), + } + + host := "load-test-host.databricks.com" + httpClient := &http.Client{} + cfg := DefaultConfig() + cfg.FlushInterval = 10 * time.Minute + + const goroutines = 100 + const opsPerGoroutine = 50 + + var wg sync.WaitGroup + var errors int64 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + client := manager.getOrCreateClient(host, "test-version", httpClient, cfg) + if client == nil { + atomic.AddInt64(&errors, 1) + return + } + + interceptor := client.GetInterceptor(true) + ctx := context.Background() + + for j := 0; j < opsPerGoroutine; j++ { + ctx2 := interceptor.BeforeExecute(ctx, "session", fmt.Sprintf("stmt-%d", j)) + interceptor.AfterExecute(ctx2, nil) + } + + _ = manager.releaseClient(host) + }() + } + + wg.Wait() + + if atomic.LoadInt64(&errors) > 0 { + t.Errorf("got %d errors during concurrent load test", atomic.LoadInt64(&errors)) + } + + // All connections released — client should be removed from map + manager.mu.RLock() + remaining := len(manager.clients) + manager.mu.RUnlock() + if remaining != 0 { + t.Errorf("expected 0 remaining clients after all releases, got %d", remaining) + } +} + +// TestGracefulShutdown_ReferenceCountingCleanup verifies that clients are cleaned up +// only when the last connection for a host is released (reference counting is correct). +func TestGracefulShutdown_ReferenceCountingCleanup(t *testing.T) { + manager := &clientManager{ + clients: make(map[string]*clientHolder), + } + + hosts := []string{ + "host-a.databricks.com", + "host-b.databricks.com", + "host-c.databricks.com", + } + httpClient := &http.Client{} + cfg := DefaultConfig() + cfg.FlushInterval = 10 * time.Minute + + // Open 3 connections per host + for _, host := range hosts { + for i := 0; i < 3; i++ { + if client := manager.getOrCreateClient(host, "test-version", httpClient, cfg); client == nil { + t.Fatalf("expected client for host %s", host) + } + } + } + if len(manager.clients) != len(hosts) { + t.Fatalf("expected %d clients, got %d", len(hosts), len(manager.clients)) + } + + // Release 2 of 3 connections per host — clients must remain + for _, host := range hosts { + for i := 0; i < 2; i++ { + if err := manager.releaseClient(host); err != nil { + t.Errorf("unexpected error releasing client for %s: %v", host, err) + } + } + } + if len(manager.clients) != len(hosts) { + t.Errorf("expected clients to remain with 1 ref each, got %d clients", len(manager.clients)) + } + + // Release the last connection per host — all clients should be cleaned up + for _, host := range hosts { + if err := manager.releaseClient(host); err != nil { + t.Errorf("unexpected error on final release for %s: %v", host, err) + } + } + + manager.mu.RLock() + remaining := len(manager.clients) + manager.mu.RUnlock() + if remaining != 0 { + t.Errorf("expected 0 clients after all releases, got %d", remaining) + } +} + +// TestGracefulShutdown_FinalFlush verifies that metrics buffered since the last periodic +// flush are exported synchronously when the aggregator is closed, not silently dropped. 
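
// ---------------------------------------------------------------------------
// Aside (illustrative): the reference-counting invariant the cleanup test
// above checks, reduced to a sketch. Types are invented; the real manager
// additionally synchronizes with a mutex and flushes on the final release.
package main

import "fmt"

type holder struct{ refs int }

type manager struct{ clients map[string]*holder }

func (m *manager) acquire(host string) {
	h, ok := m.clients[host]
	if !ok {
		h = &holder{}
		m.clients[host] = h
	}
	h.refs++ // one ref per open connection to this host
}

func (m *manager) release(host string) {
	if h, ok := m.clients[host]; ok {
		h.refs--
		if h.refs == 0 {
			delete(m.clients, host) // last connection closed: drop the shared client
		}
	}
}

func main() {
	m := &manager{clients: map[string]*holder{}}
	m.acquire("host-a")
	m.acquire("host-a")
	m.release("host-a")
	fmt.Println(len(m.clients)) // 1: still referenced
	m.release("host-a")
	fmt.Println(len(m.clients)) // 0: cleaned up on last release
}
// ---------------------------------------------------------------------------
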
+func TestGracefulShutdown_FinalFlush(t *testing.T) { + var flushed int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&flushed, 1) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + cfg := DefaultConfig() + cfg.BatchSize = 100 + cfg.FlushInterval = 10 * time.Minute // prevent auto-flush + httpClient := &http.Client{Timeout: 5 * time.Second} + + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + agg := newMetricsAggregator(exporter, cfg) + + ctx := context.Background() + + // Record a few metrics without filling the batch (no auto-flush triggers) + for i := 0; i < 5; i++ { + agg.recordMetric(ctx, &telemetryMetric{ + metricType: "operation", + timestamp: time.Now(), + sessionID: "test-session", + latencyMs: int64(i), + tags: map[string]interface{}{"operation_type": OperationTypeExecuteStatement}, + }) + } + + // Close must flush remaining metrics synchronously before returning + if err := agg.close(ctx); err != nil { + t.Errorf("expected no error on close, got %v", err) + } + + if atomic.LoadInt64(&flushed) == 0 { + t.Error("expected metrics to be flushed synchronously on close, but no HTTP request was made") + } +} diff --git a/telemetry/client.go b/telemetry/client.go index 38ed6972..25c68e88 100644 --- a/telemetry/client.go +++ b/telemetry/client.go @@ -31,9 +31,9 @@ type telemetryClient struct { } // newTelemetryClient creates a new telemetry client for the given host. -func newTelemetryClient(host string, httpClient *http.Client, cfg *Config) *telemetryClient { +func newTelemetryClient(host string, driverVersion string, httpClient *http.Client, cfg *Config) *telemetryClient { // Create exporter - exporter := newTelemetryExporter(host, httpClient, cfg) + exporter := newTelemetryExporter(host, driverVersion, httpClient, cfg) // Create aggregator with exporter aggregator := newMetricsAggregator(exporter, cfg) diff --git a/telemetry/config.go b/telemetry/config.go index 7bc76d00..1238333f 100644 --- a/telemetry/config.go +++ b/telemetry/config.go @@ -5,8 +5,6 @@ import ( "net/http" "strconv" "time" - - "github.com/databricks/databricks-sql-go/internal/config" ) // Config holds telemetry configuration. @@ -14,12 +12,9 @@ type Config struct { // Enabled controls whether telemetry is active Enabled bool - // EnableTelemetry is the client-side telemetry preference. - // Uses config overlay pattern: client > server > default - // - Unset: use server feature flag (default behavior) - // - Set to true: client wants telemetry enabled (overrides server) - // - Set to false: client wants telemetry disabled (overrides server) - EnableTelemetry config.ConfigValue[bool] + // EnableTelemetry indicates user wants telemetry enabled. + // Follows client > server > default priority. + EnableTelemetry bool // BatchSize is the number of metrics to batch before flushing BatchSize int @@ -44,12 +39,12 @@ type Config struct { } // DefaultConfig returns default telemetry configuration. -// Note: Telemetry uses config overlay - controlled by server feature flags by default. -// Clients can override by explicitly setting enableTelemetry=true/false. +// Note: Telemetry is disabled by default. The default will remain false until +// server-side feature flags are wired in to control the rollout. 
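
// ---------------------------------------------------------------------------
// Aside (illustrative): what the DSN parameter parsing just below looks like
// from the caller's side, assuming this PR's package layout (telemetry is a
// top-level package of github.com/databricks/databricks-sql-go). Parameter
// values are invented for the example.
package main

import (
	"fmt"

	"github.com/databricks/databricks-sql-go/telemetry"
)

func main() {
	params := map[string]string{
		"enableTelemetry":      "true",
		"telemetry_batch_size": "50",
	}
	cfg := telemetry.ParseTelemetryConfig(params)
	fmt.Println(cfg.EnableTelemetry, cfg.BatchSize) // true 50; other fields keep defaults
}
// ---------------------------------------------------------------------------
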
func DefaultConfig() *Config { return &Config{ - Enabled: false, // Will be set based on overlay logic - EnableTelemetry: config.ConfigValue[bool]{}, // Unset = use server feature flag + Enabled: false, + EnableTelemetry: false, BatchSize: 100, FlushInterval: 5 * time.Second, MaxRetries: 3, @@ -64,12 +59,11 @@ func DefaultConfig() *Config { func ParseTelemetryConfig(params map[string]string) *Config { cfg := DefaultConfig() - // Config overlay approach: client setting overrides server feature flag - // Priority: - // 1. Client explicit setting (enableTelemetry=true/false) - overrides server - // 2. Server feature flag (when client doesn't set) - server controls - // 3. Default disabled (when server flag unavailable) - fail-safe - cfg.EnableTelemetry = config.ParseBoolConfigValue(params, "enableTelemetry") + if v, ok := params["enableTelemetry"]; ok { + if b, err := strconv.ParseBool(v); err == nil { + cfg.EnableTelemetry = b + } + } if v, ok := params["telemetry_batch_size"]; ok { if size, err := strconv.Atoi(v); err == nil && size > 0 { @@ -87,12 +81,13 @@ func ParseTelemetryConfig(params map[string]string) *Config { } // isTelemetryEnabled checks if telemetry should be enabled for this connection. -// Implements config overlay approach with clear priority order. +// Implements the priority-based decision tree for telemetry enablement. // -// Config Overlay Priority (highest to lowest): -// 1. Client Config - enableTelemetry explicitly set (true/false) - overrides server -// 2. Server Config - feature flag controls when client doesn't specify -// 3. Fail-Safe Default - disabled when server flag unavailable/errors +// Priority (highest to lowest): +// 1. enableTelemetry=true - Client opt-in (server feature flag still consulted) +// 2. enableTelemetry=false - Explicit opt-out (always disabled) +// 3. Server Feature Flag Only - Default behavior (Databricks-controlled) +// 4. 
Default - Disabled (false) // // Parameters: // - ctx: Context for the request @@ -102,18 +97,23 @@ func ParseTelemetryConfig(params map[string]string) *Config { // // Returns: // - bool: true if telemetry should be enabled, false otherwise -func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClient *http.Client) bool { - // Priority 1: Client explicitly set (overrides server) - if cfg.EnableTelemetry.IsSet() { - val, _ := cfg.EnableTelemetry.Get() - return val +func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, driverVersion string, httpClient *http.Client) bool { + // Priority 1 & 2: Respect client preference when explicitly set + // enableTelemetry=false → always disabled; enableTelemetry=true → check server flag + // When enableTelemetry is explicitly set to false, respect that + if !cfg.EnableTelemetry { + return false } - // Priority 2: Check server-side feature flag + // Priority 3 & 4: Check server-side feature flag + // This handles both: + // - User explicitly opted in (enableTelemetry=true) - respect server decision + // - Default behavior (no explicit setting) - server controls enablement flagCache := getFeatureFlagCache() - serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, httpClient) + serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, driverVersion, httpClient) if err != nil { - // Priority 3: Fail-safe default (disabled) + // On error, respect default (disabled) + // This ensures telemetry failures don't impact driver operation return false } diff --git a/telemetry/config_test.go b/telemetry/config_test.go index d5ecdc2b..55f24de3 100644 --- a/telemetry/config_test.go +++ b/telemetry/config_test.go @@ -2,26 +2,18 @@ package telemetry import ( "context" - "encoding/json" "net/http" "net/http/httptest" "testing" "time" - - "github.com/databricks/databricks-sql-go/internal/config" ) func TestDefaultConfig(t *testing.T) { cfg := DefaultConfig() - // Verify telemetry uses config overlay (nil = use server flag) + // Verify telemetry is disabled by default if cfg.Enabled { - t.Error("Expected Enabled to be false by default") - } - - // Verify EnableTelemetry is unset (config overlay - use server flag) - if cfg.EnableTelemetry.IsSet() { - t.Error("Expected EnableTelemetry to be unset (use server flag), got set") + t.Error("Expected telemetry to be disabled by default, got enabled") } // Verify other defaults @@ -58,9 +50,9 @@ func TestParseTelemetryConfig_EmptyParams(t *testing.T) { params := map[string]string{} cfg := ParseTelemetryConfig(params) - // Should return defaults - EnableTelemetry unset means use server flag - if cfg.EnableTelemetry.IsSet() { - t.Error("Expected EnableTelemetry to be unset (use server flag) when no params provided") + // Should return defaults + if cfg.Enabled { + t.Error("Expected telemetry to be disabled by default") } if cfg.BatchSize != 100 { @@ -74,8 +66,7 @@ func TestParseTelemetryConfig_EnabledTrue(t *testing.T) { } cfg := ParseTelemetryConfig(params) - val, ok := cfg.EnableTelemetry.Get() - if !ok || !val { + if !cfg.EnableTelemetry { t.Error("Expected EnableTelemetry to be true when set to 'true'") } } @@ -86,8 +77,7 @@ func TestParseTelemetryConfig_Enabled1(t *testing.T) { } cfg := ParseTelemetryConfig(params) - val, ok := cfg.EnableTelemetry.Get() - if !ok || !val { + if !cfg.EnableTelemetry { t.Error("Expected EnableTelemetry to be true when set to '1'") } } @@ -98,8 +88,7 @@ func TestParseTelemetryConfig_EnabledFalse(t *testing.T) { } cfg := ParseTelemetryConfig(params) - 
val, ok := cfg.EnableTelemetry.Get()
-	if !ok || val {
+	if cfg.EnableTelemetry {
 		t.Error("Expected EnableTelemetry to be false when set to 'false'")
 	}
 }
@@ -182,8 +171,7 @@ func TestParseTelemetryConfig_MultipleParams(t *testing.T) {
 	}
 	cfg := ParseTelemetryConfig(params)
 
-	val, ok := cfg.EnableTelemetry.Get()
-	if !ok || !val {
+	if !cfg.EnableTelemetry {
 		t.Error("Expected EnableTelemetry to be true")
 	}
 
@@ -201,56 +189,41 @@
 	}
 }
 
-// TestIsTelemetryEnabled_ClientOverrideEnabled tests Priority 1: client explicitly enables (overrides server)
-func TestIsTelemetryEnabled_ClientOverrideEnabled(t *testing.T) {
-	// Setup: Create a server that returns disabled
+// TestIsTelemetryEnabled_ExplicitOptOut tests Priority 2 (client opt-out): enableTelemetry=false
+func TestIsTelemetryEnabled_ExplicitOptOut(t *testing.T) {
+	// Setup: Create a server that returns enabled
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		// Server says disabled, but client override should win
-		resp := map[string]interface{}{
-			"flags": map[string]bool{
-				"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false,
-			},
-		}
-		_ = json.NewEncoder(w).Encode(resp)
+		// Even if server says enabled, explicit opt-out should disable
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}], "ttl_seconds": 300}`))
 	}))
 	defer server.Close()
 
 	cfg := &Config{
-		EnableTelemetry: config.NewConfigValue(true), // Priority 1: Client explicitly enables
+		EnableTelemetry: false, // Priority 2: Explicit opt-out
 	}
 
 	ctx := context.Background()
 	httpClient := &http.Client{Timeout: 5 * time.Second}
 
-	// Setup feature flag cache context
-	flagCache := getFeatureFlagCache()
-	flagCache.getOrCreateContext(server.URL)
-	defer flagCache.releaseContext(server.URL)
-
-	// Client override should bypass server check
-	result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
+	result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient)
 
-	if !result {
-		t.Error("Expected telemetry to be enabled when client explicitly sets enableTelemetry=true, got disabled")
+	if result {
+		t.Error("Expected telemetry to be disabled with EnableTelemetry=false, got enabled")
 	}
 }
 
-// TestIsTelemetryEnabled_ClientOverrideDisabled tests Priority 1: client explicitly disables (overrides server)
-func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) {
+// TestIsTelemetryEnabled_UserOptInServerEnabled tests Priority 1 (client opt-in): user opts in + server enabled
+func TestIsTelemetryEnabled_UserOptInServerEnabled(t *testing.T) {
 	// Setup: Create a server that returns enabled
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		// Server says enabled, but client override should win
-		resp := map[string]interface{}{
-			"flags": map[string]bool{
-				"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true,
-			},
-		}
-		_ = json.NewEncoder(w).Encode(resp)
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}], "ttl_seconds": 300}`))
 	}))
 	defer server.Close()
 
 	cfg := &Config{
-		EnableTelemetry: config.NewConfigValue(false), // Priority 1: Client explicitly disables
+		
EnableTelemetry: true, // User wants telemetry } ctx := context.Background() @@ -261,28 +234,24 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) { flagCache.getOrCreateContext(server.URL) defer flagCache.releaseContext(server.URL) - result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient) + result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient) - if result { - t.Error("Expected telemetry to be disabled when client explicitly sets enableTelemetry=false, got enabled") + if !result { + t.Error("Expected telemetry to be enabled when user opts in and server allows, got disabled") } } -// TestIsTelemetryEnabled_ServerEnabled tests Priority 2: server flag enables (client didn't set) -func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) { - // Setup: Create a server that returns enabled +// TestIsTelemetryEnabled_UserOptInServerDisabled tests: user opts in but server disabled +func TestIsTelemetryEnabled_UserOptInServerDisabled(t *testing.T) { + // Setup: Create a server that returns disabled server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := map[string]interface{}{ - "flags": map[string]bool{ - "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true, - }, - } - _ = json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}], "ttl_seconds": 300}`)) })) defer server.Close() cfg := &Config{ - EnableTelemetry: config.ConfigValue[bool]{}, // Client didn't set - use server flag + EnableTelemetry: true, // User wants telemetry } ctx := context.Background() @@ -293,28 +262,24 @@ func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) { flagCache.getOrCreateContext(server.URL) defer flagCache.releaseContext(server.URL) - result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient) + result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient) - if !result { - t.Error("Expected telemetry to be enabled when server flag is true, got disabled") + if result { + t.Error("Expected telemetry to be disabled when server disables it, got enabled") } } -// TestIsTelemetryEnabled_ServerDisabled tests Priority 2: server flag disables (client didn't set) -func TestIsTelemetryEnabled_ServerDisabled(t *testing.T) { - // Setup: Create a server that returns disabled +// TestIsTelemetryEnabled_ServerFlagOnly tests: default EnableTelemetry=false is always disabled +func TestIsTelemetryEnabled_ServerFlagOnly(t *testing.T) { + // Setup: Create a server that returns enabled server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := map[string]interface{}{ - "flags": map[string]bool{ - "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false, - }, - } - _ = json.NewEncoder(w).Encode(resp) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}], "ttl_seconds": 300}`)) })) defer server.Close() cfg := &Config{ - EnableTelemetry: config.ConfigValue[bool]{}, // Client didn't set - use server flag + EnableTelemetry: false, // Default: no explicit user preference } ctx := context.Background() @@ -325,29 +290,29 @@ func TestIsTelemetryEnabled_ServerDisabled(t *testing.T) { 
flagCache.getOrCreateContext(server.URL)
 	defer flagCache.releaseContext(server.URL)
 
-	result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
+	result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient)
 
+	// When enableTelemetry is false (default), should return false (Priority 2)
 	if result {
-		t.Error("Expected telemetry to be disabled when server flag is false, got enabled")
+		t.Error("Expected telemetry to be disabled with default EnableTelemetry=false, got enabled")
 	}
 }
 
-// TestIsTelemetryEnabled_FailSafeDefault tests Priority 3: default disabled when server unavailable
-func TestIsTelemetryEnabled_FailSafeDefault(t *testing.T) {
+// TestIsTelemetryEnabled_Default tests Priority 4: default disabled
+func TestIsTelemetryEnabled_Default(t *testing.T) {
 	cfg := DefaultConfig()
 
 	ctx := context.Background()
 	httpClient := &http.Client{Timeout: 5 * time.Second}
 
-	// No server available, should default to disabled (fail-safe)
-	result := isTelemetryEnabled(ctx, cfg, "nonexistent-host", httpClient)
+	result := isTelemetryEnabled(ctx, cfg, "test-host", "test-version", httpClient)
 
 	if result {
-		t.Error("Expected telemetry to be disabled when server unavailable (fail-safe), got enabled")
+		t.Error("Expected telemetry to be disabled by default, got enabled")
 	}
 }
 
-// TestIsTelemetryEnabled_ServerError tests Priority 3: fail-safe default on server error
+// TestIsTelemetryEnabled_ServerError tests error handling
 func TestIsTelemetryEnabled_ServerError(t *testing.T) {
 	// Setup: Create a server that returns error
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -356,7 +321,7 @@ func TestIsTelemetryEnabled_ServerError(t *testing.T) {
 	defer server.Close()
 
 	cfg := &Config{
-		EnableTelemetry: config.ConfigValue[bool]{}, // Client didn't set - should use server, but server errors
+		EnableTelemetry: true, // User wants telemetry
 	}
 
 	ctx := context.Background()
@@ -367,18 +332,18 @@
 	flagCache.getOrCreateContext(server.URL)
 	defer flagCache.releaseContext(server.URL)
 
-	result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
+	result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient)
 
-	// On error, should default to disabled (fail-safe)
+	// On error, should default to disabled
 	if result {
-		t.Error("Expected telemetry to be disabled on server error (fail-safe), got enabled")
+		t.Error("Expected telemetry to be disabled on server error, got enabled")
 	}
 }
 
-// TestIsTelemetryEnabled_ServerUnreachable tests Priority 3: fail-safe default on unreachable server
+// TestIsTelemetryEnabled_ServerUnreachable tests unreachable server
func TestIsTelemetryEnabled_ServerUnreachable(t *testing.T) {
 	cfg := &Config{
-		EnableTelemetry: config.ConfigValue[bool]{}, // Client didn't set - should use server, but server unreachable
+		EnableTelemetry: true, // User wants telemetry
 	}
 
 	ctx := context.Background()
@@ -390,38 +355,10 @@
 	flagCache.getOrCreateContext(unreachableHost)
 	defer flagCache.releaseContext(unreachableHost)
 
-	result := isTelemetryEnabled(ctx, cfg, unreachableHost, httpClient)
+	result := isTelemetryEnabled(ctx, cfg, unreachableHost, "test-version", httpClient)
 
-	// On error, should default to disabled (fail-safe)
+	// On error, should default to disabled
 	if result {
-		t.Error("Expected telemetry to be disabled when server unreachable (fail-safe), got enabled")
-	}
-}
-
-// 
TestIsTelemetryEnabled_ClientOverridesServerError tests Priority 1 > Priority 3 -func TestIsTelemetryEnabled_ClientOverridesServerError(t *testing.T) { - // Setup: Create a server that returns error - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer server.Close() - - cfg := &Config{ - EnableTelemetry: config.NewConfigValue(true), // Client explicitly enables - should override server error - } - - ctx := context.Background() - httpClient := &http.Client{Timeout: 5 * time.Second} - - // Setup feature flag cache context - flagCache := getFeatureFlagCache() - flagCache.getOrCreateContext(server.URL) - defer flagCache.releaseContext(server.URL) - - result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient) - - // Client override should work even when server errors - if !result { - t.Error("Expected telemetry to be enabled when client explicitly sets true, even with server error, got disabled") + t.Error("Expected telemetry to be disabled when server unreachable, got enabled") } } diff --git a/telemetry/driver_integration.go b/telemetry/driver_integration.go index 5dcd2f71..743b4107 100644 --- a/telemetry/driver_integration.go +++ b/telemetry/driver_integration.go @@ -3,6 +3,7 @@ package telemetry import ( "context" "net/http" + "time" "github.com/databricks/databricks-sql-go/internal/config" ) @@ -14,6 +15,7 @@ import ( // Parameters: // - ctx: Context for the initialization // - host: Databricks host +// - driverVersion: Driver version string // - httpClient: HTTP client for making requests // - enableTelemetry: Client config overlay (unset = check server flag, true/false = override server) // @@ -22,30 +24,50 @@ import ( func InitializeForConnection( ctx context.Context, host string, + driverVersion string, httpClient *http.Client, enableTelemetry config.ConfigValue[bool], + batchSize int, + flushInterval time.Duration, ) *Interceptor { - // Create telemetry config and apply client overlay + // Create telemetry config and apply client overlay. 
+ // ConfigValue[bool] semantics: + // - unset → true (let server feature flag decide) + // - true → true (server feature flag still consulted) + // - false → false (explicitly disabled, skip server flag check) cfg := DefaultConfig() - cfg.EnableTelemetry = enableTelemetry + if val, isSet := enableTelemetry.Get(); isSet { + cfg.EnableTelemetry = val + } else { + cfg.EnableTelemetry = true // Unset: default to enabled, server flag decides + } + if batchSize > 0 { + cfg.BatchSize = batchSize + } + if flushInterval > 0 { + cfg.FlushInterval = flushInterval + } + + // Get feature flag cache context FIRST (for reference counting) + flagCache := getFeatureFlagCache() + flagCache.getOrCreateContext(host) // Check if telemetry should be enabled - if !isTelemetryEnabled(ctx, cfg, host, httpClient) { + enabled := isTelemetryEnabled(ctx, cfg, host, driverVersion, httpClient) + if !enabled { + flagCache.releaseContext(host) return nil } // Get or create telemetry client for this host clientMgr := getClientManager() - telemetryClient := clientMgr.getOrCreateClient(host, httpClient, cfg) + telemetryClient := clientMgr.getOrCreateClient(host, driverVersion, httpClient, cfg) if telemetryClient == nil { + // Client failed to start; release the flag cache ref we incremented above + flagCache.releaseContext(host) return nil } - // Get feature flag cache context (for reference counting) - flagCache := getFeatureFlagCache() - flagCache.getOrCreateContext(host) - - // Return interceptor return telemetryClient.GetInterceptor(true) } diff --git a/telemetry/errors.go b/telemetry/errors.go index edc2319e..0702dd06 100644 --- a/telemetry/errors.go +++ b/telemetry/errors.go @@ -78,14 +78,6 @@ func classifyError(err error) string { return "error" } -// isRetryableError returns true if the error is retryable. -// This is the inverse of isTerminalError. -// -//nolint:deadcode,unused // Will be used in Phase 8+ -func isRetryableError(err error) bool { - return !isTerminalError(err) -} - // httpError represents an HTTP error with status code. type httpError struct { @@ -97,16 +89,6 @@ func (e *httpError) Error() string { return e.message } -// newHTTPError creates a new HTTP error. -// -//nolint:deadcode,unused // Will be used in Phase 8+ -func newHTTPError(statusCode int, message string) error { - return &httpError{ - statusCode: statusCode, - message: message, - } -} - // isTerminalHTTPStatus returns true for non-retryable HTTP status codes. func isTerminalHTTPStatus(status int) bool { diff --git a/telemetry/exporter.go b/telemetry/exporter.go index 307fb85a..53fbc14a 100644 --- a/telemetry/exporter.go +++ b/telemetry/exporter.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "strings" "time" @@ -12,9 +13,17 @@ import ( "github.com/databricks/databricks-sql-go/logger" ) +const ( + telemetryEndpointPath = "/telemetry-ext" + httpPrefix = "http://" + httpsPrefix = "https://" + defaultScheme = "https://" +) + // telemetryExporter exports metrics to Databricks telemetry service. type telemetryExporter struct { host string + driverVersion string httpClient *http.Client circuitBreaker *circuitBreaker cfg *Config @@ -32,11 +41,6 @@ type telemetryMetric struct { tags map[string]interface{} } -// telemetryPayload is the JSON structure sent to Databricks. -type telemetryPayload struct { - Metrics []*exportedMetric `json:"metrics"` -} - // exportedMetric is a single metric in the payload. 
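
// ---------------------------------------------------------------------------
// Aside (illustrative): how a connection is expected to call the
// InitializeForConnection entry point shown above. This cannot be a
// standalone program because config.ConfigValue lives in an internal
// package, so it is sketched as the call site inside the driver; the field
// names on c are hypothetical.
//
//	interceptor := telemetry.InitializeForConnection(
//		ctx, host, driverVersion, httpClient,
//		cfg.EnableTelemetry, // config.ConfigValue[bool]; unset lets the server flag decide
//		0, 0,                // batchSize/flushInterval <= 0 keep the defaults
//	)
//	if interceptor != nil {
//		c.telemetry = interceptor // nil means telemetry is off for this connection
//	}
// ---------------------------------------------------------------------------
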
type exportedMetric struct { MetricType string `json:"metric_type"` @@ -49,10 +53,19 @@ type exportedMetric struct { Tags map[string]interface{} `json:"tags,omitempty"` } +// ensureHTTPScheme adds https:// prefix to host if no scheme is present. +func ensureHTTPScheme(host string) string { + if strings.HasPrefix(host, httpPrefix) || strings.HasPrefix(host, httpsPrefix) { + return host + } + return defaultScheme + host +} + // newTelemetryExporter creates a new exporter. -func newTelemetryExporter(host string, httpClient *http.Client, cfg *Config) *telemetryExporter { +func newTelemetryExporter(host string, driverVersion string, httpClient *http.Client, cfg *Config) *telemetryExporter { return &telemetryExporter{ host: host, + driverVersion: driverVersion, httpClient: httpClient, circuitBreaker: getCircuitBreakerManager().getCircuitBreaker(host), cfg: cfg, @@ -65,8 +78,7 @@ func (e *telemetryExporter) export(ctx context.Context, metrics []*telemetryMetr // Swallow all errors and panics defer func() { if r := recover(); r != nil { - // Intentionally swallow panic - telemetry must not impact driver - logger.Debug().Msgf("telemetry: export panic: %v", r) + logger.Trace().Msgf("telemetry: export panic: %v", r) } }() @@ -81,38 +93,27 @@ func (e *telemetryExporter) export(ctx context.Context, metrics []*telemetryMetr } if err != nil { - // Intentionally swallow error - telemetry must not impact driver - _ = err // Log at trace level only: logger.Trace().Msgf("telemetry: export error: %v", err) + logger.Trace().Msgf("telemetry: export error: %v", err) } } // doExport performs the actual export with retries and exponential backoff. func (e *telemetryExporter) doExport(ctx context.Context, metrics []*telemetryMetric) error { - // Convert metrics to exported format with tag filtering - exportedMetrics := make([]*exportedMetric, 0, len(metrics)) - for _, m := range metrics { - exportedMetrics = append(exportedMetrics, m.toExportedMetric()) - } - - // Create payload - payload := &telemetryPayload{ - Metrics: exportedMetrics, + // Create telemetry request with base64-encoded logs + request, err := createTelemetryRequest(metrics, e.driverVersion) + if err != nil { + return fmt.Errorf("failed to create telemetry request: %w", err) } - // Serialize metrics - data, err := json.Marshal(payload) + // Serialize request + data, err := json.Marshal(request) if err != nil { - return fmt.Errorf("failed to marshal metrics: %w", err) + return fmt.Errorf("failed to marshal request: %w", err) } // Determine endpoint - // Support both plain hosts and full URLs (for testing) - var endpoint string - if strings.HasPrefix(e.host, "http://") || strings.HasPrefix(e.host, "https://") { - endpoint = fmt.Sprintf("%s/telemetry-ext", e.host) - } else { - endpoint = fmt.Sprintf("https://%s/telemetry-ext", e.host) - } + hostURL := ensureHTTPScheme(e.host) + endpoint := hostURL + telemetryEndpointPath // Retry logic with exponential backoff maxRetries := e.cfg.MaxRetries @@ -122,7 +123,6 @@ func (e *telemetryExporter) doExport(ctx context.Context, metrics []*telemetryMe backoff := time.Duration(1<= 200 && resp.StatusCode < 300 { diff --git a/telemetry/exporter_test.go b/telemetry/exporter_test.go index bb772453..864f0c7e 100644 --- a/telemetry/exporter_test.go +++ b/telemetry/exporter_test.go @@ -16,7 +16,7 @@ func TestNewTelemetryExporter(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} host := "test-host" - exporter := newTelemetryExporter(host, httpClient, cfg) + exporter := newTelemetryExporter(host, 
"test-version", httpClient, cfg) if exporter.host != host { t.Errorf("Expected host %s, got %s", host, exporter.host) @@ -54,15 +54,15 @@ func TestExport_Success(t *testing.T) { t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } - // Verify payload structure + // Verify payload structure (new TelemetryRequest format) body, _ := io.ReadAll(r.Body) - var payload telemetryPayload + var payload TelemetryRequest if err := json.Unmarshal(body, &payload); err != nil { t.Errorf("Failed to unmarshal payload: %v", err) } - if len(payload.Metrics) != 1 { - t.Errorf("Expected 1 metric, got %d", len(payload.Metrics)) + if len(payload.ProtoLogs) != 1 { + t.Errorf("Expected 1 proto log, got %d", len(payload.ProtoLogs)) } w.WriteHeader(http.StatusOK) @@ -73,7 +73,7 @@ func TestExport_Success(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} // Use full server URL for testing - exporter := newTelemetryExporter(server.URL, httpClient, cfg) + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) metrics := []*telemetryMetric{ { @@ -113,7 +113,7 @@ func TestExport_RetryOn5xx(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} // Use full server URL for testing - exporter := newTelemetryExporter(server.URL, httpClient, cfg) + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) metrics := []*telemetryMetric{ { @@ -145,7 +145,7 @@ func TestExport_NonRetryable4xx(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} // Use full server URL for testing - exporter := newTelemetryExporter(server.URL, httpClient, cfg) + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) metrics := []*telemetryMetric{ { @@ -181,7 +181,7 @@ func TestExport_Retry429(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} // Use full server URL for testing - exporter := newTelemetryExporter(server.URL, httpClient, cfg) + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) metrics := []*telemetryMetric{ { @@ -211,7 +211,7 @@ func TestExport_CircuitBreakerOpen(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} // Use full server URL for testing - exporter := newTelemetryExporter(server.URL, httpClient, cfg) + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) // Open the circuit breaker by recording failures cb := exporter.circuitBreaker @@ -334,7 +334,7 @@ func TestExport_ErrorSwallowing(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} // Use full server URL for testing - exporter := newTelemetryExporter(server.URL, httpClient, cfg) + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) metrics := []*telemetryMetric{ { @@ -370,7 +370,7 @@ func TestExport_ContextCancellation(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} // Use full server URL for testing - exporter := newTelemetryExporter(server.URL, httpClient, cfg) + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) metrics := []*telemetryMetric{ { @@ -403,7 +403,7 @@ func TestExport_ExponentialBackoff(t *testing.T) { httpClient := &http.Client{Timeout: 5 * time.Second} // Use full server URL for testing - exporter := newTelemetryExporter(server.URL, httpClient, cfg) + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) metrics := []*telemetryMetric{ { diff --git a/telemetry/featureflag.go b/telemetry/featureflag.go index 
6943e455..81696baa 100644 --- a/telemetry/featureflag.go +++ b/telemetry/featureflag.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/http" - "strings" "sync" "time" ) @@ -16,12 +15,10 @@ const ( featureFlagCacheDuration = 15 * time.Minute // featureFlagHTTPTimeout is the default timeout for feature flag HTTP requests featureFlagHTTPTimeout = 10 * time.Second - - // Feature flag names - // flagEnableTelemetry controls whether telemetry is enabled for the Go driver - flagEnableTelemetry = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver" - // Add more feature flags here as needed: - // flagEnableNewFeature = "databricks.partnerplatform.clientConfigsFeatureFlags.enableNewFeatureForGoDriver" + // featureFlagEndpointPath is the path for feature flag endpoint + featureFlagEndpointPath = "/api/2.0/connector-service/feature-flags/GOLANG/" + // featureFlagName is the name of the Go driver telemetry feature flag + featureFlagName = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver" ) // featureFlagCache manages feature flag state per host with reference counting. @@ -33,12 +30,12 @@ type featureFlagCache struct { // featureFlagContext holds feature flag state and reference count for a host. type featureFlagContext struct { - mu sync.RWMutex // protects flags, lastFetched, fetching - flags map[string]bool // cached feature flags by name - lastFetched time.Time // when flags were last fetched - refCount int // protected by featureFlagCache.mu - cacheDuration time.Duration // how long to cache flags - fetching bool // true if a fetch is in progress + mu sync.RWMutex // protects enabled, lastFetched, fetching + enabled *bool + lastFetched time.Time + refCount int // protected by featureFlagCache.mu + cacheDuration time.Duration + fetching bool // true if a fetch is in progress } var ( @@ -87,125 +84,89 @@ func (c *featureFlagCache) releaseContext(host string) { } } -// getFeatureFlag retrieves a specific feature flag value for the host. -// This is the generic method that handles caching and fetching for any flag. +// isTelemetryEnabled checks if telemetry is enabled for the host. // Uses cached value if available and not expired. -func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, httpClient *http.Client, flagName string) (bool, error) { +func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, driverVersion string, httpClient *http.Client) (bool, error) { c.mu.RLock() flagCtx, exists := c.contexts[host] c.mu.RUnlock() - // If context doesn't exist, create it and make initial blocking fetch if !exists { - c.mu.Lock() - // Double-check after acquiring write lock - flagCtx, exists = c.contexts[host] - if !exists { - flagCtx = &featureFlagContext{ - cacheDuration: featureFlagCacheDuration, - fetching: true, // Mark as fetching - } - c.contexts[host] = flagCtx - } - c.mu.Unlock() - - // If we just created the context, make the initial blocking fetch - if !exists { - flags, err := fetchFeatureFlags(ctx, host, httpClient) - - flagCtx.mu.Lock() - flagCtx.fetching = false - if err == nil { - flagCtx.flags = flags - flagCtx.lastFetched = time.Now() - result := flags[flagName] - flagCtx.mu.Unlock() - return result, nil - } - // On error for first fetch, fail-safe: return false (telemetry disabled) - flagCtx.mu.Unlock() - return false, nil - } + return false, nil } - // Check if cache is valid (with proper locking) + // Fast path: check cache under read lock. 
flagCtx.mu.RLock() - if flagCtx.flags != nil && time.Since(flagCtx.lastFetched) < flagCtx.cacheDuration { - // Cache is valid, return the cached flag value - enabled := flagCtx.flags[flagName] // returns false if flag not found + if flagCtx.enabled != nil && time.Since(flagCtx.lastFetched) < flagCtx.cacheDuration { + enabled := *flagCtx.enabled flagCtx.mu.RUnlock() return enabled, nil } - - // Check if another goroutine is already fetching if flagCtx.fetching { - // Return cached value if available, otherwise wait - if flagCtx.flags != nil { - enabled := flagCtx.flags[flagName] + if flagCtx.enabled != nil { + enabled := *flagCtx.enabled flagCtx.mu.RUnlock() return enabled, nil } flagCtx.mu.RUnlock() - // No cached value and fetch in progress, return false return false, nil } + flagCtx.mu.RUnlock() - // Mark as fetching + // Slow path: need a write lock to set fetching=true. + // Re-check all conditions under write lock (double-checked locking) to avoid + // a data race and to prevent duplicate fetches from concurrent goroutines. + flagCtx.mu.Lock() + if flagCtx.enabled != nil && time.Since(flagCtx.lastFetched) < flagCtx.cacheDuration { + enabled := *flagCtx.enabled + flagCtx.mu.Unlock() + return enabled, nil + } + if flagCtx.fetching { + if flagCtx.enabled != nil { + enabled := *flagCtx.enabled + flagCtx.mu.Unlock() + return enabled, nil + } + flagCtx.mu.Unlock() + return false, nil + } flagCtx.fetching = true - flagCtx.mu.RUnlock() + flagCtx.mu.Unlock() - // Fetch fresh values for all flags - flags, err := fetchFeatureFlags(ctx, host, httpClient) + // Fetch fresh value (outside lock so other readers are not blocked). + enabled, err := fetchFeatureFlag(ctx, host, driverVersion, httpClient) - // Update cache (with proper locking) + // Update cache. flagCtx.mu.Lock() flagCtx.fetching = false if err == nil { - flagCtx.flags = flags + flagCtx.enabled = &enabled flagCtx.lastFetched = time.Now() } - // On error, keep the old cached values if they exist result := false var returnErr error if err != nil { - if flagCtx.flags != nil { - result = flagCtx.flags[flagName] - returnErr = nil // Return cached value without error + if flagCtx.enabled != nil { + result = *flagCtx.enabled // Return stale cached value on error } else { returnErr = err } } else { - result = flags[flagName] + result = enabled } flagCtx.mu.Unlock() return result, returnErr } -// isTelemetryEnabled checks if telemetry is enabled for the host. -// Uses cached value if available and not expired. -func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, httpClient *http.Client) (bool, error) { - return c.getFeatureFlag(ctx, host, httpClient, flagEnableTelemetry) -} - // isExpired returns true if the cache has expired. func (c *featureFlagContext) isExpired() bool { - return c.flags == nil || time.Since(c.lastFetched) > c.cacheDuration -} - -// getAllFeatureFlags returns a list of all feature flags to fetch. -// Add new flags here when adding new features. -func getAllFeatureFlags() []string { - return []string{ - flagEnableTelemetry, - // Add more flags here as needed: - // flagEnableNewFeature, - } + return c.enabled == nil || time.Since(c.lastFetched) > c.cacheDuration } -// fetchFeatureFlags fetches multiple feature flag values from Databricks in a single request. -// Returns a map of flag names to their boolean values. -func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client) (map[string]bool, error) { +// fetchFeatureFlag fetches the feature flag value from Databricks. 
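
// ---------------------------------------------------------------------------
// Aside (illustrative): the fast-path/slow-path shape above as a
// self-contained sketch. The cache type is invented; the real code also
// tracks a fetching flag and releases the lock during the HTTP fetch so
// readers are never blocked on the network.
package main

import (
	"fmt"
	"sync"
	"time"
)

type cache struct {
	mu      sync.RWMutex
	value   *bool
	fetched time.Time
	ttl     time.Duration
}

func (c *cache) get(fetch func() bool) bool {
	c.mu.RLock() // fast path: read lock only
	if c.value != nil && time.Since(c.fetched) < c.ttl {
		v := *c.value
		c.mu.RUnlock()
		return v
	}
	c.mu.RUnlock()

	c.mu.Lock() // slow path: re-check under the write lock (double-checked locking)
	defer c.mu.Unlock()
	if c.value != nil && time.Since(c.fetched) < c.ttl {
		return *c.value // another goroutine refreshed between RUnlock and Lock
	}
	v := fetch()
	c.value, c.fetched = &v, time.Now()
	return v
}

func main() {
	c := &cache{ttl: time.Minute}
	fmt.Println(c.get(func() bool { return true }))  // fetches, caches true
	fmt.Println(c.get(func() bool { return false })) // served from cache: still true
}
// ---------------------------------------------------------------------------
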
+func fetchFeatureFlag(ctx context.Context, host string, driverVersion string, httpClient *http.Client) (bool, error) { // Add timeout to context if it doesn't have a deadline if _, hasDeadline := ctx.Deadline(); !hasDeadline { var cancel context.CancelFunc @@ -213,50 +174,50 @@ func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client defer cancel() } - // Construct endpoint URL, adding https:// if not already present - var endpoint string - if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") { - endpoint = fmt.Sprintf("%s/api/2.0/feature-flags", host) - } else { - endpoint = fmt.Sprintf("https://%s/api/2.0/feature-flags", host) - } + // Construct endpoint URL using connector-service endpoint like JDBC + hostURL := ensureHTTPScheme(host) + endpoint := fmt.Sprintf("%s%s%s", hostURL, featureFlagEndpointPath, driverVersion) req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil) if err != nil { - return nil, fmt.Errorf("failed to create feature flag request: %w", err) + return false, fmt.Errorf("failed to create feature flag request: %w", err) } - // Add query parameter with comma-separated list of feature flags - // This fetches all flags in a single request for efficiency - allFlags := getAllFeatureFlags() - q := req.URL.Query() - q.Add("flags", strings.Join(allFlags, ",")) - req.URL.RawQuery = q.Encode() - resp, err := httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to fetch feature flags: %w", err) + return false, fmt.Errorf("failed to fetch feature flag: %w", err) } - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { // Read and discard body to allow HTTP connection reuse _, _ = io.Copy(io.Discard, resp.Body) - return nil, fmt.Errorf("feature flag check failed: %d", resp.StatusCode) + return false, fmt.Errorf("feature flag check failed: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return false, fmt.Errorf("failed to read feature flag response: %w", err) } var result struct { - Flags map[string]bool `json:"flags"` + Flags []struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"flags"` + TTLSeconds int `json:"ttl_seconds"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("failed to decode feature flag response: %w", err) + if err := json.Unmarshal(body, &result); err != nil { + return false, fmt.Errorf("failed to decode feature flag response: %w", err) } - // Return the full map of flags - // Flags not present in the response will have false value when accessed - if result.Flags == nil { - return make(map[string]bool), nil + // Look for Go driver telemetry feature flag + for _, flag := range result.Flags { + if flag.Name == featureFlagName { + enabled := flag.Value == "true" + return enabled, nil + } } - return result.Flags, nil + return false, nil } diff --git a/telemetry/featureflag_test.go b/telemetry/featureflag_test.go index b0aa519a..4ffbf07b 100644 --- a/telemetry/featureflag_test.go +++ b/telemetry/featureflag_test.go @@ -94,13 +94,12 @@ func TestFeatureFlagCache_IsTelemetryEnabled_Cached(t *testing.T) { ctx := cache.getOrCreateContext(host) // Set cached value - ctx.flags = map[string]bool{ - flagEnableTelemetry: true, - } + enabled := true + ctx.enabled = &enabled ctx.lastFetched = time.Now() // Should return cached value without HTTP call - result, err := cache.isTelemetryEnabled(context.Background(), host, nil) + result, err := 
cache.isTelemetryEnabled(context.Background(), host, "test-version", nil) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -116,7 +115,7 @@ func TestFeatureFlagCache_IsTelemetryEnabled_Expired(t *testing.T) { callCount++ w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"flags": {"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true}}`)) + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}], "ttl_seconds": 300}`)) })) defer server.Close() @@ -128,14 +127,13 @@ func TestFeatureFlagCache_IsTelemetryEnabled_Expired(t *testing.T) { ctx := cache.getOrCreateContext(host) // Set expired cached value - ctx.flags = map[string]bool{ - flagEnableTelemetry: false, - } + enabled := false + ctx.enabled = &enabled ctx.lastFetched = time.Now().Add(-20 * time.Minute) // Expired // Should fetch fresh value httpClient := &http.Client{} - result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + result, err := cache.isTelemetryEnabled(context.Background(), host, "test-version", httpClient) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -147,7 +145,7 @@ func TestFeatureFlagCache_IsTelemetryEnabled_Expired(t *testing.T) { } // Verify cache was updated - if ctx.flags[flagEnableTelemetry] != true { + if *ctx.enabled != true { t.Error("Expected cache to be updated with new value") } } @@ -159,15 +157,14 @@ func TestFeatureFlagCache_IsTelemetryEnabled_NoContext(t *testing.T) { host := "non-existent-host.databricks.com" - // Should return false for non-existent context (network error expected) - httpClient := &http.Client{Timeout: 1 * time.Second} - result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) - // Error expected due to network failure, but should not panic + // Should return false for non-existent context + result, err := cache.isTelemetryEnabled(context.Background(), host, "test-version", nil) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } if result != false { t.Error("Expected false for non-existent context") } - // err is expected to be non-nil due to DNS/network failure, but that's okay - _ = err } func TestFeatureFlagCache_IsTelemetryEnabled_ErrorFallback(t *testing.T) { @@ -185,14 +182,13 @@ func TestFeatureFlagCache_IsTelemetryEnabled_ErrorFallback(t *testing.T) { ctx := cache.getOrCreateContext(host) // Set cached value - ctx.flags = map[string]bool{ - flagEnableTelemetry: true, - } + enabled := true + ctx.enabled = &enabled ctx.lastFetched = time.Now().Add(-20 * time.Minute) // Expired // Should return cached value on error httpClient := &http.Client{} - result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + result, err := cache.isTelemetryEnabled(context.Background(), host, "test-version", httpClient) if err != nil { t.Errorf("Expected no error (fallback to cache), got %v", err) } @@ -217,7 +213,7 @@ func TestFeatureFlagCache_IsTelemetryEnabled_ErrorNoCache(t *testing.T) { // No cached value, should return error httpClient := &http.Client{} - result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + result, err := cache.isTelemetryEnabled(context.Background(), host, "test-version", httpClient) if err == nil { t.Error("Expected error when no cache available and fetch fails") } @@ -275,28 +271,28 @@ func TestFeatureFlagCache_ConcurrentAccess(t *testing.T) { func 
TestFeatureFlagContext_IsExpired(t *testing.T) { tests := []struct { name string - flags map[string]bool + enabled *bool fetched time.Time duration time.Duration want bool }{ { name: "no cache", - flags: nil, + enabled: nil, fetched: time.Time{}, duration: 15 * time.Minute, want: true, }, { name: "fresh cache", - flags: map[string]bool{flagEnableTelemetry: true}, + enabled: boolPtr(true), fetched: time.Now(), duration: 15 * time.Minute, want: false, }, { name: "expired cache", - flags: map[string]bool{flagEnableTelemetry: true}, + enabled: boolPtr(true), fetched: time.Now().Add(-20 * time.Minute), duration: 15 * time.Minute, want: true, @@ -306,7 +302,7 @@ func TestFeatureFlagContext_IsExpired(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &featureFlagContext{ - flags: tt.flags, + enabled: tt.enabled, lastFetched: tt.fetched, cacheDuration: tt.duration, } @@ -317,82 +313,73 @@ func TestFeatureFlagContext_IsExpired(t *testing.T) { } } -func TestFetchFeatureFlags_Success(t *testing.T) { +func TestFetchFeatureFlag_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify request + // Verify request method if r.Method != "GET" { t.Errorf("Expected GET request, got %s", r.Method) } - if r.URL.Path != "/api/2.0/feature-flags" { - t.Errorf("Expected /api/2.0/feature-flags path, got %s", r.URL.Path) - } - - flags := r.URL.Query().Get("flags") - expectedFlag := "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver" - if flags != expectedFlag { - t.Errorf("Expected flag query param %s, got %s", expectedFlag, flags) - } - // Return success response + // Return success response using new connector-service format w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"flags": {"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true}}`)) + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}], "ttl_seconds": 300}`)) })) defer server.Close() host := server.URL // Use full URL for testing httpClient := &http.Client{} - flags, err := fetchFeatureFlags(context.Background(), host, httpClient) + enabled, err := fetchFeatureFlag(context.Background(), host, "test-version", httpClient) if err != nil { t.Errorf("Expected no error, got %v", err) } - if !flags[flagEnableTelemetry] { - t.Error("Expected telemetry feature flag to be enabled") + if !enabled { + t.Error("Expected feature flag to be enabled") } } -func TestFetchFeatureFlags_Disabled(t *testing.T) { +func TestFetchFeatureFlag_Disabled(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"flags": {"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false}}`)) + _, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}], "ttl_seconds": 300}`)) })) defer server.Close() host := server.URL // Use full URL for testing httpClient := &http.Client{} - flags, err := fetchFeatureFlags(context.Background(), host, httpClient) + enabled, err := fetchFeatureFlag(context.Background(), host, "test-version", httpClient) if err != nil { t.Errorf("Expected no error, got %v", err) } - if 
flags[flagEnableTelemetry] { - t.Error("Expected telemetry feature flag to be disabled") + if enabled { + t.Error("Expected feature flag to be disabled") } } -func TestFetchFeatureFlags_FlagNotPresent(t *testing.T) { +func TestFetchFeatureFlag_FlagNotPresent(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"flags": {}}`)) + _, _ = w.Write([]byte(`{"flags": [], "ttl_seconds": 300}`)) })) defer server.Close() host := server.URL // Use full URL for testing httpClient := &http.Client{} - flags, err := fetchFeatureFlags(context.Background(), host, httpClient) + enabled, err := fetchFeatureFlag(context.Background(), host, "test-version", httpClient) if err != nil { t.Errorf("Expected no error, got %v", err) } - if flags[flagEnableTelemetry] { - t.Error("Expected telemetry feature flag to be false when not present") + if enabled { + t.Error("Expected feature flag to be false when not present") } } -func TestFetchFeatureFlags_HTTPError(t *testing.T) { +func TestFetchFeatureFlag_HTTPError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) @@ -401,13 +388,13 @@ func TestFetchFeatureFlags_HTTPError(t *testing.T) { host := server.URL // Use full URL for testing httpClient := &http.Client{} - _, err := fetchFeatureFlags(context.Background(), host, httpClient) + _, err := fetchFeatureFlag(context.Background(), host, "test-version", httpClient) if err == nil { t.Error("Expected error for HTTP 500") } } -func TestFetchFeatureFlags_InvalidJSON(t *testing.T) { +func TestFetchFeatureFlag_InvalidJSON(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) @@ -418,13 +405,13 @@ func TestFetchFeatureFlags_InvalidJSON(t *testing.T) { host := server.URL // Use full URL for testing httpClient := &http.Client{} - _, err := fetchFeatureFlags(context.Background(), host, httpClient) + _, err := fetchFeatureFlag(context.Background(), host, "test-version", httpClient) if err == nil { t.Error("Expected error for invalid JSON") } } -func TestFetchFeatureFlags_ContextCancellation(t *testing.T) { +func TestFetchFeatureFlag_ContextCancellation(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(100 * time.Millisecond) w.WriteHeader(http.StatusOK) @@ -437,8 +424,13 @@ func TestFetchFeatureFlags_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately - _, err := fetchFeatureFlags(ctx, host, httpClient) + _, err := fetchFeatureFlag(ctx, host, "test-version", httpClient) if err == nil { t.Error("Expected error for cancelled context") } } + +// Helper function to create bool pointer +func boolPtr(b bool) *bool { + return &b +} diff --git a/telemetry/integration_test.go b/telemetry/integration_test.go new file mode 100644 index 00000000..13a6c8b3 --- /dev/null +++ b/telemetry/integration_test.go @@ -0,0 +1,311 @@ +package telemetry + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestIntegration_EndToEnd_WithCircuitBreaker tests complete end-to-end flow. 
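The integration tests below all share the same harness: an httptest stub standing in for the /telemetry-ext endpoint, plus an atomic counter so assertions stay race-free. Distilled to its core (paths and values illustrative):

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
)

func main() {
	var requests int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&requests, 1) // handlers run on server goroutines, hence atomic
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	// Anything pointed at server.URL hits the stub; here, one plain request.
	resp, err := http.Post(server.URL+"/telemetry-ext", "application/json", nil)
	if err != nil {
		panic(err)
	}
	resp.Body.Close() //nolint:errcheck

	fmt.Println("requests seen:", atomic.LoadInt32(&requests)) // requests seen: 1
}
```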
+func TestIntegration_EndToEnd_WithCircuitBreaker(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := DefaultConfig() + cfg.FlushInterval = 100 * time.Millisecond + cfg.BatchSize = 5 + httpClient := &http.Client{Timeout: 5 * time.Second} + + requestCount := int32(0) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&requestCount, 1) + + // Verify request structure + if r.Method != "POST" { + t.Errorf("Expected POST, got %s", r.Method) + } + if r.URL.Path != "/telemetry-ext" { + t.Errorf("Expected /telemetry-ext, got %s", r.URL.Path) + } + + // Parse payload (new TelemetryRequest format) + body, _ := io.ReadAll(r.Body) + var payload TelemetryRequest + if err := json.Unmarshal(body, &payload); err != nil { + t.Errorf("Failed to parse payload: %v", err) + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create telemetry client + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + aggregator := newMetricsAggregator(exporter, cfg) + defer aggregator.close(context.Background()) //nolint:errcheck + + interceptor := newInterceptor(aggregator, true) + + // Simulate statement execution + ctx := context.Background() + for i := 0; i < 10; i++ { + statementID := "stmt-integration" + ctx = interceptor.BeforeExecute(ctx, "session-id", statementID) + time.Sleep(10 * time.Millisecond) // Simulate work + interceptor.AfterExecute(ctx, nil) + interceptor.CompleteStatement(ctx, statementID, false) + } + + // Wait for flush + time.Sleep(200 * time.Millisecond) + + // Verify requests were sent + count := atomic.LoadInt32(&requestCount) + if count == 0 { + t.Error("Expected telemetry requests to be sent") + } + + t.Logf("Integration test: sent %d requests", count) +} + +// TestIntegration_CircuitBreakerOpening tests circuit breaker behavior under failures. 
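The breaker's internals are not part of this diff; the toy below only mirrors the documented trigger (20+ calls at a 50%+ failure rate) to illustrate the state transition the next test waits for:

```go
package main

import "fmt"

// toyBreaker is an illustrative stand-in, not the driver's implementation:
// it opens once it has seen at least 20 calls with a failure rate of 50% or more.
type toyBreaker struct {
	calls, failures int
	open            bool
}

func (b *toyBreaker) record(failed bool) {
	b.calls++
	if failed {
		b.failures++
	}
	if b.calls >= 20 && b.failures*2 >= b.calls {
		b.open = true
	}
}

func main() {
	b := &toyBreaker{}
	for i := 0; i < 25; i++ {
		b.record(true) // every request fails, as with the stub server in the test
	}
	fmt.Println("open:", b.open) // open: true
}
```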
+func TestIntegration_CircuitBreakerOpening(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := DefaultConfig() + cfg.FlushInterval = 50 * time.Millisecond + cfg.MaxRetries = 0 // No retries for faster test + httpClient := &http.Client{Timeout: 5 * time.Second} + + requestCount := int32(0) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&requestCount, 1) + // Always fail to trigger circuit breaker + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + aggregator := newMetricsAggregator(exporter, cfg) + defer aggregator.close(context.Background()) //nolint:errcheck + + interceptor := newInterceptor(aggregator, true) + cb := exporter.circuitBreaker + + // Send enough requests to open circuit (need 20+ calls with 50%+ failure rate) + ctx := context.Background() + for i := 0; i < 50; i++ { + statementID := "stmt-circuit" + ctx = interceptor.BeforeExecute(ctx, "session-id", statementID) + interceptor.AfterExecute(ctx, nil) + interceptor.CompleteStatement(ctx, statementID, false) + + // Small delay to ensure each batch is processed + time.Sleep(20 * time.Millisecond) + } + + // Wait for flush and circuit breaker evaluation + time.Sleep(500 * time.Millisecond) + + // Verify circuit opened (may still be closed if not enough failures recorded) + state := cb.getState() + t.Logf("Circuit breaker state after failures: %v", state) + + // Circuit should eventually open, but timing is async + // If not open, at least verify requests were attempted + initialCount := atomic.LoadInt32(&requestCount) + if initialCount == 0 { + t.Error("Expected at least some requests to be sent") + } + + // Send more requests - should be dropped if circuit is open + for i := 0; i < 10; i++ { + statementID := "stmt-dropped" + ctx = interceptor.BeforeExecute(ctx, "session-id", statementID) + interceptor.AfterExecute(ctx, nil) + interceptor.CompleteStatement(ctx, statementID, false) + } + + time.Sleep(200 * time.Millisecond) + + finalCount := atomic.LoadInt32(&requestCount) + t.Logf("Circuit breaker test: %d requests sent, state=%v", finalCount, cb.getState()) + + // Test passes if either: + // 1. Circuit opened and requests were dropped, OR + // 2. Circuit is still trying (which is also acceptable for async system) + if state == stateOpen && finalCount > initialCount+5 { + t.Errorf("Expected requests to be dropped when circuit open, got %d additional requests", finalCount-initialCount) + } +} + +// TestIntegration_OptInPriority_ExplicitOptOut tests explicit opt-out. 
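The test below pins the precedence rule: the client's explicit setting is consulted before the server-side flag, so a server that reports "enabled" cannot override an opt-out. Condensed to a truth table (helper name hypothetical):

```go
package main

import "fmt"

// telemetryEnabled is a hypothetical condensation of the precedence:
// telemetry runs only when the client explicitly opts in AND the
// server-side flag allows it, so an explicit opt-out always wins.
func telemetryEnabled(clientOptIn, serverFlagEnabled bool) bool {
	return clientOptIn && serverFlagEnabled
}

func main() {
	fmt.Println(telemetryEnabled(false, true)) // false: opt-out wins even if the server says enabled
	fmt.Println(telemetryEnabled(true, false)) // false: opt-in is still gated by the server flag
	fmt.Println(telemetryEnabled(true, true))  // true
}
```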
+func TestIntegration_OptInPriority_ExplicitOptOut(t *testing.T) { + cfg := &Config{ + EnableTelemetry: false, // Priority 1 (client): Explicit opt-out + BatchSize: 100, + FlushInterval: 5 * time.Second, + MaxRetries: 3, + RetryDelay: 100 * time.Millisecond, + } + + httpClient := &http.Client{Timeout: 5 * time.Second} + + // Server that returns enabled + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]interface{}{ + "flags": map[string]bool{ + "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + ctx := context.Background() + + // Should be disabled due to explicit opt-out + result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient) + + if result { + t.Error("Expected telemetry to be disabled by explicit opt-out") + } +} + +// TestIntegration_PrivacyCompliance verifies no sensitive data is collected. +func TestIntegration_PrivacyCompliance_NoQueryText(t *testing.T) { + cfg := DefaultConfig() + cfg.FlushInterval = 50 * time.Millisecond + httpClient := &http.Client{Timeout: 5 * time.Second} + + var mu sync.Mutex + var capturedBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + mu.Lock() + capturedBody = body + mu.Unlock() + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + aggregator := newMetricsAggregator(exporter, cfg) + defer aggregator.close(context.Background()) //nolint:errcheck + + interceptor := newInterceptor(aggregator, true) + + ctx := context.Background() + statementID := "stmt-privacy" + ctx = interceptor.BeforeExecute(ctx, "session-id", statementID) + + // Add sensitive tags — none of these should appear in the exported telemetry + interceptor.AddTag(ctx, "query.text", "SELECT * FROM users") + interceptor.AddTag(ctx, "user.email", "user@example.com") + + interceptor.AfterExecute(ctx, nil) + interceptor.CompleteStatement(ctx, statementID, false) + + // Wait for flush + time.Sleep(200 * time.Millisecond) + + mu.Lock() + body := capturedBody + mu.Unlock() + + if len(body) == 0 { + t.Fatal("Expected telemetry request to be sent") + } + + // The exporter sends TelemetryRequest with ProtoLogs (JSON-encoded TelemetryFrontendLog). + // Verify sensitive values are absent from the serialised payload. + bodyStr := string(body) + if strings.Contains(bodyStr, "SELECT * FROM users") { + t.Error("Query text must not be exported") + } + if strings.Contains(bodyStr, "user@example.com") { + t.Error("User email must not be exported") + } + + t.Log("Privacy compliance test passed: sensitive data not present in payload") +} + +// TestIntegration_FieldMapping verifies that only known metric fields are exported +// in the TelemetryRequest format (no generic tag pass-through). 
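Worth keeping in mind while reading the next two tests: ProtoLogs is doubly encoded, that is, each entry is itself a complete JSON document inside the outer JSON envelope. A trimmed sketch of the round trip (local copies of the request.go shapes, just enough to show the nesting):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type telemetryRequest struct {
	UploadTime int64    `json:"uploadTime"`
	Items      []string `json:"items"`
	ProtoLogs  []string `json:"protoLogs"`
}

type frontendLog struct {
	FrontendLogEventID string `json:"frontend_log_event_id,omitempty"`
}

func main() {
	inner, _ := json.Marshal(frontendLog{FrontendLogEventID: "evt-1"})
	body, _ := json.Marshal(telemetryRequest{
		UploadTime: 1700000000000,
		Items:      []string{}, // required but empty, as in createTelemetryRequest
		ProtoLogs:  []string{string(inner)},
	})
	fmt.Println(string(body))
	// {"uploadTime":1700000000000,"items":[],"protoLogs":["{\"frontend_log_event_id\":\"evt-1\"}"]}

	// Consumers therefore unmarshal twice: the envelope, then each entry.
	var round telemetryRequest
	_ = json.Unmarshal(body, &round)
	var entry frontendLog
	_ = json.Unmarshal([]byte(round.ProtoLogs[0]), &entry)
	fmt.Println(entry.FrontendLogEventID) // evt-1
}
```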
+func TestIntegration_FieldMapping(t *testing.T) { + cfg := DefaultConfig() + cfg.FlushInterval = 50 * time.Millisecond + httpClient := &http.Client{Timeout: 5 * time.Second} + + var capturedRequest TelemetryRequest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &capturedRequest) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + + metric := &telemetryMetric{ + metricType: "connection", + timestamp: time.Now(), + workspaceID: "ws-test", + sessionID: "sess-1", + latencyMs: 42, + tags: map[string]interface{}{ + "chunk_count": 3, + "bytes_downloaded": int64(1024), + "unknown.tag": "value", // should NOT appear in output + }, + } + + ctx := context.Background() + exporter.export(ctx, []*telemetryMetric{metric}) + + time.Sleep(150 * time.Millisecond) + + if len(capturedRequest.ProtoLogs) == 0 { + t.Fatal("Expected at least one ProtoLog entry") + } + + // Each ProtoLog entry is a JSON-encoded TelemetryFrontendLog. + var log TelemetryFrontendLog + if err := json.Unmarshal([]byte(capturedRequest.ProtoLogs[0]), &log); err != nil { + t.Fatalf("Failed to unmarshal ProtoLog: %v", err) + } + + if log.Entry == nil || log.Entry.SQLDriverLog == nil { + t.Fatal("Expected SQLDriverLog to be populated") + } + + entry := log.Entry.SQLDriverLog + if entry.SessionID != "sess-1" { + t.Errorf("Expected session_id=sess-1, got %q", entry.SessionID) + } + if entry.OperationLatencyMs != 42 { + t.Errorf("Expected latency=42, got %d", entry.OperationLatencyMs) + } + if entry.SQLOperation != nil && entry.SQLOperation.ChunkDetails != nil { + if entry.SQLOperation.ChunkDetails.TotalChunksIterated != 3 { + t.Errorf("Expected total_chunks_iterated=3, got %d", entry.SQLOperation.ChunkDetails.TotalChunksIterated) + } + } + + // unknown.tag must not appear anywhere in the serialised output + if strings.Contains(capturedRequest.ProtoLogs[0], "unknown.tag") { + t.Error("unknown.tag must not be exported") + } + + t.Log("Field mapping test passed") +} diff --git a/telemetry/interceptor.go b/telemetry/interceptor.go index 419b4900..5ef01b1a 100644 --- a/telemetry/interceptor.go +++ b/telemetry/interceptor.go @@ -166,6 +166,34 @@ func (i *Interceptor) CompleteStatement(ctx context.Context, statementID string, i.aggregator.completeStatement(ctx, statementID, failed) } +// RecordOperation records an operation with type, latency, and optional error. +// Exported for use by the driver package. +func (i *Interceptor) RecordOperation(ctx context.Context, sessionID string, operationType string, latencyMs int64, err error) { + if !i.enabled { + return + } + + defer func() { + if r := recover(); r != nil { + logger.Trace().Msgf("telemetry: recordOperation panic: %v", r) + } + }() + + metric := &telemetryMetric{ + metricType: "operation", + timestamp: time.Now(), + sessionID: sessionID, + latencyMs: latencyMs, + tags: map[string]interface{}{"operation_type": operationType}, + } + + if err != nil { + metric.errorType = classifyError(err) + } + + i.aggregator.recordMetric(ctx, metric) +} + // Close flushes any pending per-connection metrics. 
// Does NOT close the shared aggregator — its lifecycle is managed via // ReleaseForConnection, which uses reference counting across all connections @@ -176,6 +204,12 @@ func (i *Interceptor) Close(ctx context.Context) error { return nil } - i.aggregator.flush(ctx) + defer func() { + if r := recover(); r != nil { + logger.Debug().Msgf("telemetry: Close panic: %v", r) + } + }() + + i.aggregator.flushSync(ctx) return nil } diff --git a/telemetry/manager.go b/telemetry/manager.go index 33bfe1cf..8977e924 100644 --- a/telemetry/manager.go +++ b/telemetry/manager.go @@ -45,13 +45,13 @@ func getClientManager() *clientManager { // getOrCreateClient gets or creates a telemetry client for the host. // Increments reference count. -func (m *clientManager) getOrCreateClient(host string, httpClient *http.Client, cfg *Config) *telemetryClient { +func (m *clientManager) getOrCreateClient(host string, driverVersion string, httpClient *http.Client, cfg *Config) *telemetryClient { m.mu.Lock() defer m.mu.Unlock() holder, exists := m.clients[host] if !exists { - client := newTelemetryClient(host, httpClient, cfg) + client := newTelemetryClient(host, driverVersion, httpClient, cfg) if err := client.start(); err != nil { // Failed to start client, don't add to map logger.Logger.Debug().Str("host", host).Err(err).Msg("failed to start telemetry client") diff --git a/telemetry/manager_test.go b/telemetry/manager_test.go index 59127e24..51461452 100644 --- a/telemetry/manager_test.go +++ b/telemetry/manager_test.go @@ -29,7 +29,7 @@ func TestClientManager_GetOrCreateClient(t *testing.T) { cfg := DefaultConfig() // First call should create client and increment refCount to 1 - client1 := manager.getOrCreateClient(host, httpClient, cfg) + client1 := manager.getOrCreateClient(host, "test-version", httpClient, cfg) if client1 == nil { t.Fatal("Expected client to be created") } @@ -46,7 +46,7 @@ func TestClientManager_GetOrCreateClient(t *testing.T) { } // Second call should reuse client and increment refCount to 2 - client2 := manager.getOrCreateClient(host, httpClient, cfg) + client2 := manager.getOrCreateClient(host, "test-version", httpClient, cfg) if client2 != client1 { t.Error("Expected to get the same client instance") } @@ -65,8 +65,8 @@ func TestClientManager_GetOrCreateClient_DifferentHosts(t *testing.T) { httpClient := &http.Client{} cfg := DefaultConfig() - client1 := manager.getOrCreateClient(host1, httpClient, cfg) - client2 := manager.getOrCreateClient(host2, httpClient, cfg) + client1 := manager.getOrCreateClient(host1, "test-version", httpClient, cfg) + client2 := manager.getOrCreateClient(host2, "test-version", httpClient, cfg) if client1 == client2 { t.Error("Expected different clients for different hosts") @@ -87,8 +87,8 @@ func TestClientManager_ReleaseClient(t *testing.T) { cfg := DefaultConfig() // Create client with refCount = 2 - manager.getOrCreateClient(host, httpClient, cfg) - manager.getOrCreateClient(host, httpClient, cfg) + manager.getOrCreateClient(host, "test-version", httpClient, cfg) + manager.getOrCreateClient(host, "test-version", httpClient, cfg) // First release should decrement to 1 err := manager.releaseClient(host) @@ -151,7 +151,7 @@ func TestClientManager_ConcurrentAccess(t *testing.T) { for i := 0; i < numGoroutines; i++ { go func() { defer wg.Done() - client := manager.getOrCreateClient(host, httpClient, cfg) + client := manager.getOrCreateClient(host, "test-version", httpClient, cfg) if client == nil { t.Error("Expected client to be created") } @@ -207,7 +207,7 @@ func 
TestClientManager_ConcurrentAccessMultipleHosts(t *testing.T) { wg.Add(1) go func(h string) { defer wg.Done() - _ = manager.getOrCreateClient(h, httpClient, cfg) + _ = manager.getOrCreateClient(h, "test-version", httpClient, cfg) }(host) } } @@ -241,7 +241,7 @@ func TestClientManager_ReleaseClientPartial(t *testing.T) { // Create 5 references for i := 0; i < 5; i++ { - manager.getOrCreateClient(host, httpClient, cfg) + manager.getOrCreateClient(host, "test-version", httpClient, cfg) } // Release 3 references @@ -271,7 +271,7 @@ func TestClientManager_ClientStartCalled(t *testing.T) { httpClient := &http.Client{} cfg := DefaultConfig() - client := manager.getOrCreateClient(host, httpClient, cfg) + client := manager.getOrCreateClient(host, "test-version", httpClient, cfg) if !client.started { t.Error("Expected start() to be called on new client") @@ -287,7 +287,7 @@ func TestClientManager_ClientCloseCalled(t *testing.T) { httpClient := &http.Client{} cfg := DefaultConfig() - client := manager.getOrCreateClient(host, httpClient, cfg) + client := manager.getOrCreateClient(host, "test-version", httpClient, cfg) _ = manager.releaseClient(host) if !client.closed { @@ -305,9 +305,9 @@ func TestClientManager_MultipleGetOrCreateSameClient(t *testing.T) { cfg := DefaultConfig() // Get same client multiple times - client1 := manager.getOrCreateClient(host, httpClient, cfg) - client2 := manager.getOrCreateClient(host, httpClient, cfg) - client3 := manager.getOrCreateClient(host, httpClient, cfg) + client1 := manager.getOrCreateClient(host, "test-version", httpClient, cfg) + client2 := manager.getOrCreateClient(host, "test-version", httpClient, cfg) + client3 := manager.getOrCreateClient(host, "test-version", httpClient, cfg) // All should be same instance if client1 != client2 || client2 != client3 { @@ -337,7 +337,7 @@ func TestClientManager_Shutdown(t *testing.T) { // Create clients for multiple hosts clients := make([]*telemetryClient, 0, len(hosts)) for _, host := range hosts { - client := manager.getOrCreateClient(host, httpClient, cfg) + client := manager.getOrCreateClient(host, "test-version", httpClient, cfg) clients = append(clients, client) } @@ -375,9 +375,9 @@ func TestClientManager_ShutdownWithActiveRefs(t *testing.T) { cfg := DefaultConfig() // Create client with multiple references - client := manager.getOrCreateClient(host, httpClient, cfg) - manager.getOrCreateClient(host, httpClient, cfg) - manager.getOrCreateClient(host, httpClient, cfg) + client := manager.getOrCreateClient(host, "test-version", httpClient, cfg) + manager.getOrCreateClient(host, "test-version", httpClient, cfg) + manager.getOrCreateClient(host, "test-version", httpClient, cfg) holder := manager.clients[host] if holder.refCount != 3 { diff --git a/telemetry/operation_type.go b/telemetry/operation_type.go new file mode 100644 index 00000000..3ba6f213 --- /dev/null +++ b/telemetry/operation_type.go @@ -0,0 +1,16 @@ +package telemetry + +const ( + OperationTypeUnspecified = "TYPE_UNSPECIFIED" + OperationTypeCreateSession = "CREATE_SESSION" + OperationTypeDeleteSession = "DELETE_SESSION" + OperationTypeExecuteStatement = "EXECUTE_STATEMENT" + OperationTypeCloseStatement = "CLOSE_STATEMENT" +) + +// isTerminalOperationType returns true for operations that signal the end of a +// session or statement lifecycle. These flush immediately so metrics are not +// lost if the connection closes before the next periodic flush — matching JDBC. 
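A sketch of the intended gating at a call site, with the predicate reproduced locally so it runs standalone:

```go
package main

import "fmt"

const (
	operationTypeExecuteStatement = "EXECUTE_STATEMENT"
	operationTypeDeleteSession    = "DELETE_SESSION"
	operationTypeCloseStatement   = "CLOSE_STATEMENT"
)

// Same predicate as operation_type.go, copied here so the sketch compiles.
func isTerminalOperationType(opType string) bool {
	return opType == operationTypeDeleteSession || opType == operationTypeCloseStatement
}

func main() {
	for _, op := range []string{operationTypeExecuteStatement, operationTypeDeleteSession} {
		if isTerminalOperationType(op) {
			fmt.Println(op, "-> flush immediately") // terminal: metrics must survive connection close
		} else {
			fmt.Println(op, "-> buffer until the periodic flush")
		}
	}
}
```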
+func isTerminalOperationType(opType string) bool { + return opType == OperationTypeDeleteSession || opType == OperationTypeCloseStatement +} diff --git a/telemetry/request.go b/telemetry/request.go new file mode 100644 index 00000000..f516406d --- /dev/null +++ b/telemetry/request.go @@ -0,0 +1,228 @@ +package telemetry + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "time" +) + +// TelemetryRequest is the top-level request sent to the telemetry endpoint. +type TelemetryRequest struct { + UploadTime int64 `json:"uploadTime"` + Items []string `json:"items"` + ProtoLogs []string `json:"protoLogs"` +} + +// TelemetryFrontendLog represents a single telemetry log entry. +type TelemetryFrontendLog struct { + WorkspaceID int64 `json:"workspace_id,omitempty"` + FrontendLogEventID string `json:"frontend_log_event_id,omitempty"` + Context *FrontendLogContext `json:"context,omitempty"` + Entry *FrontendLogEntry `json:"entry,omitempty"` +} + +// FrontendLogContext contains the client context. +type FrontendLogContext struct { + ClientContext *TelemetryClientContext `json:"client_context,omitempty"` +} + +// TelemetryClientContext contains client-level information. +type TelemetryClientContext struct { + ClientType string `json:"client_type,omitempty"` + ClientVersion string `json:"client_version,omitempty"` +} + +// FrontendLogEntry contains the actual telemetry event. +type FrontendLogEntry struct { + SQLDriverLog *TelemetryEvent `json:"sql_driver_log,omitempty"` +} + +// TelemetryEvent maps to OssSqlDriverTelemetryLog in the proto schema. +type TelemetryEvent struct { + SessionID string `json:"session_id,omitempty"` + SQLStatementID string `json:"sql_statement_id,omitempty"` + SystemConfiguration *DriverSystemConfiguration `json:"system_configuration,omitempty"` + DriverConnectionParameters *DriverConnectionParameters `json:"driver_connection_params,omitempty"` + AuthType string `json:"auth_type,omitempty"` + VolumeOperation *VolumeOperationEvent `json:"vol_operation,omitempty"` + SQLOperation *SQLExecutionEvent `json:"sql_operation,omitempty"` + ErrorInfo *DriverErrorInfo `json:"error_info,omitempty"` + OperationLatencyMs int64 `json:"operation_latency_ms,omitempty"` +} + +// DriverSystemConfiguration maps to DriverSystemConfiguration in the proto schema. +type DriverSystemConfiguration struct { + DriverVersion string `json:"driver_version,omitempty"` + RuntimeName string `json:"runtime_name,omitempty"` + RuntimeVersion string `json:"runtime_version,omitempty"` + RuntimeVendor string `json:"runtime_vendor,omitempty"` + OSName string `json:"os_name,omitempty"` + OSVersion string `json:"os_version,omitempty"` + OSArch string `json:"os_arch,omitempty"` + DriverName string `json:"driver_name,omitempty"` + ClientAppName string `json:"client_app_name,omitempty"` + LocaleName string `json:"locale_name,omitempty"` + CharSetEncoding string `json:"char_set_encoding,omitempty"` + ProcessName string `json:"process_name,omitempty"` +} + +// HostDetails maps to HostDetails in the proto schema. +type HostDetails struct { + HostURL string `json:"host_url,omitempty"` + Port int32 `json:"port,omitempty"` + ProxyAuthType string `json:"proxy_auth_type,omitempty"` +} + +// DriverConnectionParameters maps to DriverConnectionParameters in the proto schema. +// Only fields populated by the Go driver are included; others are omitted. 
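Because every field in these schema types carries omitempty, anything the Go driver leaves unset disappears from the wire format entirely. A quick demonstration with the HostDetails shape (hostname illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the HostDetails shape from request.go.
type hostDetails struct {
	HostURL string `json:"host_url,omitempty"`
	Port    int32  `json:"port,omitempty"`
}

func main() {
	full, _ := json.Marshal(hostDetails{HostURL: "example.cloud.databricks.com", Port: 443})
	sparse, _ := json.Marshal(hostDetails{HostURL: "example.cloud.databricks.com"})
	fmt.Println(string(full))   // {"host_url":"example.cloud.databricks.com","port":443}
	fmt.Println(string(sparse)) // {"host_url":"example.cloud.databricks.com"}
}
```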
+type DriverConnectionParameters struct { + HTTPPath string `json:"http_path,omitempty"` + Mode string `json:"mode,omitempty"` + HostInfo *HostDetails `json:"host_info,omitempty"` + UseProxy bool `json:"use_proxy,omitempty"` + AuthMech string `json:"auth_mech,omitempty"` + AuthFlow string `json:"auth_flow,omitempty"` + AuthScope string `json:"auth_scope,omitempty"` + UseSystemProxy bool `json:"use_system_proxy,omitempty"` + UseCFProxy bool `json:"use_cf_proxy,omitempty"` + EnableArrow bool `json:"enable_arrow,omitempty"` + EnableDirectResults bool `json:"enable_direct_results,omitempty"` + QueryTags string `json:"query_tags,omitempty"` + EnableMetricViewMeta bool `json:"enable_metric_view_metadata,omitempty"` + SocketTimeout int64 `json:"socket_timeout,omitempty"` +} + +// SQLExecutionEvent maps to SqlExecutionEvent in the proto schema. +type SQLExecutionEvent struct { + StatementType string `json:"statement_type,omitempty"` + IsCompressed bool `json:"is_compressed,omitempty"` + ExecutionResult string `json:"execution_result,omitempty"` + ChunkID int64 `json:"chunk_id,omitempty"` + RetryCount int64 `json:"retry_count,omitempty"` + ChunkDetails *ChunkDetails `json:"chunk_details,omitempty"` + ResultLatency *ResultLatency `json:"result_latency,omitempty"` + OperationDetail *OperationDetail `json:"operation_detail,omitempty"` + JavaUsesPatchedArrow bool `json:"java_uses_patched_arrow,omitempty"` +} + +// ChunkDetails maps to ChunkDetails in the proto schema. +type ChunkDetails struct { + InitialChunkLatencyMs int64 `json:"initial_chunk_latency_millis,omitempty"` + SlowestChunkLatencyMs int64 `json:"slowest_chunk_latency_millis,omitempty"` + TotalChunksPresent int32 `json:"total_chunks_present,omitempty"` + TotalChunksIterated int32 `json:"total_chunks_iterated,omitempty"` + SumChunksDownloadTimeMs int64 `json:"sum_chunks_download_time_millis,omitempty"` +} + +// ResultLatency maps to ResultLatency in the proto schema. +type ResultLatency struct { + ResultSetReadyLatencyMs int64 `json:"result_set_ready_latency_millis,omitempty"` + ResultSetConsumptionLatencyMs int64 `json:"result_set_consumption_latency_millis,omitempty"` +} + +// OperationDetail maps to OperationDetail in the proto schema. +type OperationDetail struct { + NOperationStatusCalls int32 `json:"n_operation_status_calls,omitempty"` + OperationStatusLatencyMs int64 `json:"operation_status_latency_millis,omitempty"` + OperationType string `json:"operation_type,omitempty"` + IsInternalCall bool `json:"is_internal_call,omitempty"` +} + +// VolumeOperationEvent maps to VolumeOperationEvent in the proto schema. +type VolumeOperationEvent struct { + VolumeOperationType string `json:"volume_operation_type,omitempty"` + VolumePath string `json:"volume_path,omitempty"` + LocalFile string `json:"local_file,omitempty"` +} + +// DriverErrorInfo maps to DriverErrorInfo in the proto schema. +type DriverErrorInfo struct { + ErrorName string `json:"error_name,omitempty"` + StackTrace string `json:"stack_trace,omitempty"` +} + +// TelemetryResponse is the response from the telemetry endpoint. +type TelemetryResponse struct { + Errors []string `json:"errors"` + NumSuccess int `json:"numSuccess"` + NumProtoSuccess int `json:"numProtoSuccess"` + NumRealtimeSuccess int `json:"numRealtimeSuccess"` +} + +// createTelemetryRequest creates a telemetry request from metrics. 
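One subtlety in the mapping below: tag values come out of a map[string]interface{} via two-result type assertions, and an int64 will not satisfy an .(int) assertion. The concrete types stored in tags therefore matter, as the field-mapping test's bytes_downloaded entry shows:

```go
package main

import "fmt"

func main() {
	// Mirrors the tags map used by the field-mapping test above.
	tags := map[string]interface{}{
		"chunk_count":      3,           // stored as int
		"bytes_downloaded": int64(1024), // stored as int64
	}

	if n, ok := tags["chunk_count"].(int); ok {
		fmt.Println("chunks:", n) // chunks: 3
	}
	if _, ok := tags["bytes_downloaded"].(int); !ok {
		fmt.Println("an int64 does not satisfy an .(int) assertion")
	}
}
```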
+func createTelemetryRequest(metrics []*telemetryMetric, driverVersion string) (*TelemetryRequest, error) { + protoLogs := make([]string, 0, len(metrics)) + + for _, metric := range metrics { + frontendLog := &TelemetryFrontendLog{ + FrontendLogEventID: generateEventID(), + Context: &FrontendLogContext{ + ClientContext: &TelemetryClientContext{ + ClientType: "golang", + ClientVersion: driverVersion, + }, + }, + Entry: &FrontendLogEntry{ + SQLDriverLog: &TelemetryEvent{ + SessionID: metric.sessionID, + SQLStatementID: metric.statementID, + SystemConfiguration: getSystemConfiguration(driverVersion), + OperationLatencyMs: metric.latencyMs, + }, + }, + } + + // Add SQL operation details if available. + if tags := metric.tags; tags != nil { + sqlOp := &SQLExecutionEvent{} + + if v, ok := tags["result.format"].(string); ok { + sqlOp.ExecutionResult = v + } + if chunkCount, ok := tags["chunk_count"].(int); ok && chunkCount > 0 { + sqlOp.ChunkDetails = &ChunkDetails{ + TotalChunksIterated: int32(chunkCount), //nolint:gosec // chunk count is always small + } + } + + if opType, ok := tags["operation_type"].(string); ok { + detail := &OperationDetail{ + OperationType: opType, + } + if pollCount, ok := tags["poll_count"].(int); ok { + detail.NOperationStatusCalls = int32(pollCount) //nolint:gosec // poll count is always small + } + sqlOp.OperationDetail = detail + } + + frontendLog.Entry.SQLDriverLog.SQLOperation = sqlOp + } + + // Add error info if present. + if metric.errorType != "" { + frontendLog.Entry.SQLDriverLog.ErrorInfo = &DriverErrorInfo{ + ErrorName: metric.errorType, + } + } + + jsonBytes, err := json.Marshal(frontendLog) + if err != nil { + return nil, err + } + protoLogs = append(protoLogs, string(jsonBytes)) + } + + return &TelemetryRequest{ + UploadTime: time.Now().UnixMilli(), + Items: []string{}, // Required but empty + ProtoLogs: protoLogs, + }, nil +} + +// generateEventID generates a unique event ID using crypto/rand. +func generateEventID() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return time.Now().Format("20060102150405") + "-" + hex.EncodeToString(b) +} diff --git a/telemetry/system_info.go b/telemetry/system_info.go new file mode 100644 index 00000000..56b979b9 --- /dev/null +++ b/telemetry/system_info.go @@ -0,0 +1,110 @@ +package telemetry + +import ( + "os" + "runtime" + "strings" + "sync" +) + +// sysInfoOnce caches the parts of system configuration that are invariant across calls +// (OS info, runtime, process name) to avoid repeated os.ReadFile on every metric. 
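The caching recipe here is plain sync.Once; a minimal standalone version of the same pattern:

```go
package main

import (
	"fmt"
	"sync"
)

var (
	once     sync.Once
	cachedOS string
)

// osName runs the (stand-in) expensive probe exactly once; later and
// concurrent callers all read the cached result.
func osName() string {
	once.Do(func() {
		fmt.Println("probing...") // printed once
		cachedOS = "Linux"        // stand-in for the real detection
	})
	return cachedOS
}

func main() {
	fmt.Println(osName()) // probing... then Linux
	fmt.Println(osName()) // Linux (cached)
}
```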
+var (
+	sysInfoOnce   sync.Once
+	cachedOSName  string
+	cachedOSVer   string
+	cachedArch    string
+	cachedRuntime string
+	cachedLocale  string
+	cachedProcess string
+)
+
+func initSysInfo() {
+	sysInfoOnce.Do(func() {
+		cachedOSName = getOSName()
+		cachedOSVer = getOSVersion()
+		cachedArch = runtime.GOARCH
+		cachedRuntime = runtime.Version()
+		cachedLocale = getLocaleName()
+		cachedProcess = getProcessName()
+	})
+}
+
+func getSystemConfiguration(driverVersion string) *DriverSystemConfiguration {
+	initSysInfo()
+	return &DriverSystemConfiguration{
+		OSName:          cachedOSName,
+		OSVersion:       cachedOSVer,
+		OSArch:          cachedArch,
+		DriverName:      "databricks-sql-go",
+		DriverVersion:   driverVersion,
+		RuntimeName:     "go",
+		RuntimeVersion:  cachedRuntime,
+		RuntimeVendor:   "",
+		LocaleName:      cachedLocale,
+		CharSetEncoding: "UTF-8",
+		ProcessName:     cachedProcess,
+	}
+}
+
+func getOSName() string {
+	switch runtime.GOOS {
+	case "darwin":
+		return "macOS"
+	case "windows":
+		return "Windows"
+	case "linux":
+		return "Linux"
+	default:
+		return runtime.GOOS
+	}
+}
+
+func getOSVersion() string {
+	switch runtime.GOOS {
+	case "linux":
+		if data, err := os.ReadFile("/etc/os-release"); err == nil {
+			lines := strings.Split(string(data), "\n")
+			for _, line := range lines {
+				if strings.HasPrefix(line, "VERSION=") {
+					version := strings.TrimPrefix(line, "VERSION=")
+					version = strings.Trim(version, "\"")
+					return version
+				}
+			}
+		}
+		// Guard the /proc/version fallback: indexing field [2] unchecked
+		// would panic on a malformed file.
+		if data, err := os.ReadFile("/proc/version"); err == nil {
+			if fields := strings.Fields(string(data)); len(fields) >= 3 {
+				return fields[2]
+			}
+		}
+	}
+	return ""
+}
+
+func getLocaleName() string {
+	if lang := os.Getenv("LANG"); lang != "" {
+		parts := strings.Split(lang, ".")
+		if len(parts) > 0 {
+			return parts[0]
+		}
+	}
+	return "en_US"
+}
+
+func getProcessName() string {
+	if len(os.Args) > 0 {
+		processPath := os.Args[0]
+		lastSlash := strings.LastIndex(processPath, "/")
+		if lastSlash == -1 {
+			lastSlash = strings.LastIndex(processPath, "\\")
+		}
+		if lastSlash >= 0 {
+			processPath = processPath[lastSlash+1:]
+		}
+		dotIndex := strings.LastIndex(processPath, ".")
+		if dotIndex > 0 {
+			processPath = processPath[:dotIndex]
+		}
+		return processPath
+	}
+	return ""
+}
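As a design note on getProcessName: filepath.Base would shorten the slash handling, but it splits on the GOOS-specific separator only, while the manual scan above also strips backslashes on non-Windows builds. An equivalent sketch for the common case:

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// processName trims the directory and extension from argv[0] using the
// standard library, covering the Unix-path case getProcessName handles.
func processName(argv0 string) string {
	base := filepath.Base(argv0)
	if dot := strings.LastIndex(base, "."); dot > 0 {
		base = base[:dot]
	}
	return base
}

func main() {
	fmt.Println(processName("/usr/local/bin/myapp.test")) // myapp
}
```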