From 94d080dd888bba6cfe08ac49b2000527be827214 Mon Sep 17 00:00:00 2001 From: JY Tan Date: Thu, 9 Apr 2026 22:56:40 -0700 Subject: [PATCH 1/8] Commit --- .gitignore | 3 + drift/core/adaptive_sampling.py | 247 +++++++++++++ drift/core/batch_processor.py | 5 + drift/core/config.py | 44 +++ drift/core/drift_sdk.py | 242 +++++++++++-- drift/core/mode_utils.py | 5 +- drift/core/no_recording.py | 22 ++ drift/core/tracing/span_utils.py | 5 + drift/instrumentation/django/middleware.py | 4 +- .../fastapi/instrumentation.py | 4 +- drift/instrumentation/wsgi/handler.py | 4 +- scripts/plot_sampling_benchmarks.py | 338 ++++++++++++++++++ tests/unit/test_adaptive_sampling.py | 42 +++ tests/unit/test_config_loading.py | 32 ++ tests/unit/test_mode_utils.py | 15 +- tests/unit/test_span_utils.py | 18 + 16 files changed, 987 insertions(+), 43 deletions(-) create mode 100644 drift/core/adaptive_sampling.py create mode 100644 drift/core/no_recording.py create mode 100755 scripts/plot_sampling_benchmarks.py create mode 100644 tests/unit/test_adaptive_sampling.py diff --git a/.gitignore b/.gitignore index 1c1425a..297301c 100644 --- a/.gitignore +++ b/.gitignore @@ -220,6 +220,9 @@ __marimo__/ **/.tusk/traces/ **/.tusk/logs/ +# Bug tracking +**/BUG_TRACKING.md + # macOS .DS_Store diff --git a/drift/core/adaptive_sampling.py b/drift/core/adaptive_sampling.py new file mode 100644 index 0000000..c287ef2 --- /dev/null +++ b/drift/core/adaptive_sampling.py @@ -0,0 +1,247 @@ +"""Adaptive sampling controller for inbound root-request admission.""" + +from __future__ import annotations + +import logging +import math +import random +import time +from dataclasses import dataclass +from typing import Literal + +logger = logging.getLogger(__name__) + +SamplingMode = Literal["fixed", "adaptive"] +AdaptiveSamplingState = Literal["fixed", "healthy", "warm", "hot", "critical_pause"] +RootSamplingDecisionReason = Literal[ + "pre_app_start", + "sampled", + "not_sampled", + "load_shed", + "critical_pause", 
+] + + +@dataclass +class ResolvedSamplingConfig: + mode: SamplingMode + base_rate: float + min_rate: float + + +@dataclass +class AdaptiveSamplingHealthSnapshot: + queue_fill_ratio: float | None = None + dropped_span_count: int = 0 + export_failure_count: int = 0 + export_circuit_open: bool = False + memory_pressure_ratio: float | None = None + + +@dataclass +class RootSamplingDecision: + should_record: bool + reason: RootSamplingDecisionReason + mode: SamplingMode + state: AdaptiveSamplingState + base_rate: float + min_rate: float + effective_rate: float + admission_multiplier: float + + +def _clamp(value: float, min_value: float, max_value: float) -> float: + return min(max_value, max(min_value, value)) + + +def _clamp01(value: float) -> float: + return _clamp(value, 0.0, 1.0) + + +def _normalize_between(value: float | None, zero_point: float, one_point: float) -> float: + if value is None or one_point <= zero_point: + return 0.0 + return _clamp01((value - zero_point) / (one_point - zero_point)) + + +class AdaptiveSamplingController: + def __init__( + self, + config: ResolvedSamplingConfig, + *, + random_fn=random.random, + now_fn=time.monotonic, + ) -> None: + self._config = config + self._random_fn = random_fn + self._now_fn = now_fn + + self._admission_multiplier = 1.0 + self._state: AdaptiveSamplingState = "fixed" if config.mode == "fixed" else "healthy" + self._paused_until_s = 0.0 + self._last_updated_at_s = 0.0 + self._last_decrease_at_s = 0.0 + + self._prev_dropped_span_count = 0 + self._prev_export_failure_count = 0 + + self._queue_fill_ewma: float | None = None + self._recent_drop_signal = 0.0 + self._recent_failure_signal = 0.0 + + def update(self, snapshot: AdaptiveSamplingHealthSnapshot) -> None: + if self._config.mode != "adaptive": + self._state = "fixed" + self._admission_multiplier = 1.0 + return + + now_s = self._now_fn() + elapsed_s = 2.0 if self._last_updated_at_s == 0 else max(0.001, now_s - self._last_updated_at_s) + self._last_updated_at_s 
= now_s + + decay = math.exp(-(elapsed_s * 1000.0) / 30000.0) + self._recent_drop_signal *= decay + self._recent_failure_signal *= decay + + dropped_delta = max(0, snapshot.dropped_span_count - self._prev_dropped_span_count) + export_failure_delta = max(0, snapshot.export_failure_count - self._prev_export_failure_count) + + self._prev_dropped_span_count = snapshot.dropped_span_count + self._prev_export_failure_count = snapshot.export_failure_count + + self._recent_drop_signal += dropped_delta + self._recent_failure_signal += export_failure_delta + + if snapshot.queue_fill_ratio is not None: + queue_fill_ratio = _clamp01(snapshot.queue_fill_ratio) + self._queue_fill_ewma = ( + queue_fill_ratio + if self._queue_fill_ewma is None + else (0.25 * queue_fill_ratio) + (0.75 * self._queue_fill_ewma) + ) + + queue_pressure = _normalize_between(self._queue_fill_ewma, 0.20, 0.85) + memory_pressure = _normalize_between(snapshot.memory_pressure_ratio, 0.80, 0.92) + export_failure_pressure = _clamp01(self._recent_failure_signal / 5.0) + pressure = max(queue_pressure, memory_pressure, export_failure_pressure) + + hard_brake = ( + dropped_delta > 0 or snapshot.export_circuit_open or (snapshot.memory_pressure_ratio or 0.0) >= 0.92 + ) + + previous_state = self._state + previous_multiplier = self._admission_multiplier + + if hard_brake: + self._paused_until_s = now_s + 15.0 + self._admission_multiplier = 0.0 + self._state = "critical_pause" + self._last_decrease_at_s = now_s + self._log_transition(previous_state, previous_multiplier, pressure, snapshot) + return + + if now_s < self._paused_until_s: + self._state = "critical_pause" + self._log_transition(previous_state, previous_multiplier, pressure, snapshot) + return + + min_multiplier = self._get_min_multiplier() + if pressure >= 0.70: + self._admission_multiplier = max(min_multiplier, self._admission_multiplier * 0.4) + self._state = "hot" + self._last_decrease_at_s = now_s + elif pressure >= 0.45: + self._admission_multiplier = 
max(min_multiplier, self._admission_multiplier * 0.7) + self._state = "warm" + self._last_decrease_at_s = now_s + else: + if pressure <= 0.20 and (now_s - self._last_decrease_at_s) >= 10.0: + self._admission_multiplier = min(1.0, self._admission_multiplier + 0.05) + self._state = "healthy" + + self._log_transition(previous_state, previous_multiplier, pressure, snapshot) + + def get_decision(self, *, is_pre_app_start: bool) -> RootSamplingDecision: + if is_pre_app_start: + return RootSamplingDecision( + should_record=True, + reason="pre_app_start", + mode=self._config.mode, + state=self._state, + base_rate=self._config.base_rate, + min_rate=self._config.min_rate, + effective_rate=1.0, + admission_multiplier=1.0, + ) + + effective_rate = ( + self.get_effective_sampling_rate() if self._config.mode == "adaptive" else _clamp01(self._config.base_rate) + ) + + if effective_rate <= 0.0: + return RootSamplingDecision( + should_record=False, + reason="critical_pause" if self._state == "critical_pause" else "not_sampled", + mode=self._config.mode, + state=self._state, + base_rate=self._config.base_rate, + min_rate=self._config.min_rate, + effective_rate=effective_rate, + admission_multiplier=self._admission_multiplier, + ) + + should_record = self._random_fn() < effective_rate + return RootSamplingDecision( + should_record=should_record, + reason=( + "sampled" + if should_record + else "load_shed" + if self._config.mode == "adaptive" and effective_rate < self._config.base_rate + else "not_sampled" + ), + mode=self._config.mode, + state=self._state, + base_rate=self._config.base_rate, + min_rate=self._config.min_rate, + effective_rate=effective_rate, + admission_multiplier=self._admission_multiplier if self._config.mode == "adaptive" else 1.0, + ) + + def get_effective_sampling_rate(self) -> float: + if self._config.mode != "adaptive": + return _clamp01(self._config.base_rate) + if self._state == "critical_pause" and self._now_fn() < self._paused_until_s: + return 0.0 + 
effective_rate = self._config.base_rate * self._admission_multiplier + return _clamp( + effective_rate, + min(self._config.base_rate, self._config.min_rate), + self._config.base_rate, + ) + + def _get_min_multiplier(self) -> float: + if self._config.base_rate <= 0.0 or self._config.min_rate <= 0.0: + return 0.0 + return _clamp01(self._config.min_rate / self._config.base_rate) + + def _log_transition( + self, + previous_state: AdaptiveSamplingState, + previous_multiplier: float, + pressure: float, + snapshot: AdaptiveSamplingHealthSnapshot, + ) -> None: + if previous_state == self._state and abs(previous_multiplier - self._admission_multiplier) < 0.05: + return + + logger.info( + "Adaptive sampling updated (state=%s, multiplier=%.2f, effective_rate=%.4f, pressure=%.2f, queue_fill=%s, memory_pressure_ratio=%s, export_circuit_open=%s).", + self._state, + self._admission_multiplier, + self.get_effective_sampling_rate(), + pressure, + f"{self._queue_fill_ewma:.2f}" if self._queue_fill_ewma is not None else "n/a", + snapshot.memory_pressure_ratio if snapshot.memory_pressure_ratio is not None else "n/a", + snapshot.export_circuit_open, + ) diff --git a/drift/core/batch_processor.py b/drift/core/batch_processor.py index 13c89e9..9338657 100644 --- a/drift/core/batch_processor.py +++ b/drift/core/batch_processor.py @@ -244,3 +244,8 @@ def queue_size(self) -> int: def dropped_span_count(self) -> int: """Get the number of dropped spans.""" return self._dropped_spans + + @property + def max_queue_size(self) -> int: + """Get the configured maximum queue size.""" + return self._config.max_queue_size diff --git a/drift/core/config.py b/drift/core/config.py index 481c07d..16df032 100644 --- a/drift/core/config.py +++ b/drift/core/config.py @@ -66,11 +66,21 @@ class ComparisonConfig: ignore_fields: list[str] = field(default_factory=list) +@dataclass +class SamplingConfig: + """Configuration for fixed vs adaptive sampling.""" + + mode: str | None = None + base_rate: float | None = 
None + min_rate: float | None = None + + @dataclass class RecordingConfig: """Configuration for recording behavior.""" sampling_rate: float | None = None + sampling: SamplingConfig | None = None export_spans: bool | None = None enable_env_var_recording: bool | None = None enable_analytics: bool | None = None @@ -144,8 +154,42 @@ def _parse_recording_config(data: dict[str, Any]) -> RecordingConfig: ) sampling_rate = None + sampling = None + raw_sampling = data.get("sampling") + if isinstance(raw_sampling, dict): + base_rate = raw_sampling.get("base_rate") + if base_rate is not None and not isinstance(base_rate, (int, float)): + logger.warning( + f"Invalid 'sampling.base_rate' in config: expected number, got {type(base_rate).__name__}. " + "This value will be ignored." + ) + base_rate = None + + min_rate = raw_sampling.get("min_rate") + if min_rate is not None and not isinstance(min_rate, (int, float)): + logger.warning( + f"Invalid 'sampling.min_rate' in config: expected number, got {type(min_rate).__name__}. " + "This value will be ignored." + ) + min_rate = None + + mode = raw_sampling.get("mode") + if mode is not None and not isinstance(mode, str): + logger.warning( + f"Invalid 'sampling.mode' in config: expected string, got {type(mode).__name__}. " + "This value will be ignored." 
+ ) + mode = None + + sampling = SamplingConfig( + mode=mode, + base_rate=float(base_rate) if base_rate is not None else None, + min_rate=float(min_rate) if min_rate is not None else None, + ) + return RecordingConfig( sampling_rate=sampling_rate, + sampling=sampling, export_spans=data.get("export_spans"), enable_env_var_recording=data.get("enable_env_var_recording"), enable_analytics=data.get("enable_analytics"), diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 395f8b9..936680f 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -8,6 +8,7 @@ import platform import random import stat +import threading import time from pathlib import Path from typing import TYPE_CHECKING, Any @@ -19,6 +20,13 @@ from ..instrumentation.registry import install_hooks from ..version import SDK_VERSION +from .adaptive_sampling import ( + AdaptiveSamplingController, + AdaptiveSamplingHealthSnapshot, + ResolvedSamplingConfig, + RootSamplingDecision, + SamplingMode, +) from .communication.communicator import CommunicatorConfig, ProtobufCommunicator from .communication.types import MockRequestInput, MockResponseOutput from .config import TuskConfig, TuskFileConfig, load_tusk_config @@ -48,6 +56,12 @@ def __init__(self) -> None: self.app_ready = False self._sdk_instance_id = self._generate_sdk_instance_id() self._sampling_rate: float = 1.0 + self._sampling_mode: str = "fixed" + self._min_sampling_rate: float = 0.0 + self._adaptive_sampling_controller: AdaptiveSamplingController | None = None + self._adaptive_sampling_thread: threading.Thread | None = None + self._adaptive_sampling_stop_event = threading.Event() + self._effective_memory_limit_bytes: int | None = None self._transform_configs: dict[str, Any] | None = None self._init_params: dict[str, Any] = {} @@ -121,14 +135,16 @@ def _log_startup_summary(self, env: str, use_remote_export: bool) -> None: ) logger.info( - "SDK initialized successfully (version=%s, mode=%s, env=%s, service=%s, serviceId=%s, 
exportSpans=%s, samplingRate=%s, logLevel=%s, runtime=python %s, platform=%s/%s).", + "SDK initialized successfully (version=%s, mode=%s, env=%s, service=%s, serviceId=%s, exportSpans=%s, samplingMode=%s, samplingBaseRate=%s, samplingMinRate=%s, logLevel=%s, runtime=python %s, platform=%s/%s).", SDK_VERSION, self.mode, env, service_name, service_id, use_remote_export, + self._sampling_mode, self._sampling_rate, + self._min_sampling_rate, get_log_level(), platform.python_version(), platform.system().lower(), @@ -168,7 +184,10 @@ def initialize( "log_level": log_level, } - instance._sampling_rate = instance._determine_sampling_rate(sampling_rate) + sampling_config = instance._determine_sampling_config(sampling_rate) + instance._sampling_rate = sampling_config.base_rate + instance._sampling_mode = sampling_config.mode + instance._min_sampling_rate = sampling_config.min_rate effective_api_key = api_key or os.environ.get("TUSK_API_KEY") @@ -306,6 +325,7 @@ def initialize( install_hooks() instance._init_auto_instrumentations() + instance._start_adaptive_sampling_control_loop() # Create env vars snapshot if enabled (matches Node SDK behavior) instance.create_env_vars_snapshot() @@ -318,38 +338,171 @@ def initialize( return instance - def _determine_sampling_rate(self, init_param: float | None) -> float: - """Determine the sampling rate from various sources (precedence order).""" - # 1. 
Init param takes precedence + def _determine_sampling_config(self, init_param: float | None) -> ResolvedSamplingConfig: + """Determine the effective sampling config from init params, env, and file config.""" + recording_config = self.file_config.recording if self.file_config else None + config_sampling = recording_config.sampling if recording_config else None + + mode: SamplingMode = "fixed" + if config_sampling and config_sampling.mode in {"fixed", "adaptive"}: + mode = "adaptive" if config_sampling.mode == "adaptive" else "fixed" + elif config_sampling and config_sampling.mode: + logger.warning( + "Invalid sampling mode from config file: %s. Must be 'fixed' or 'adaptive'. Ignoring.", + config_sampling.mode, + ) + + base_rate = 1.0 if init_param is not None: validated = validate_sampling_rate(init_param, "init params") if validated is not None: logger.debug(f"Using sampling rate from init params: {validated}") - return validated - - # 2. Environment variable - env_rate = os.environ.get("TUSK_SAMPLING_RATE") - if env_rate is not None: - try: - parsed = float(env_rate) - validated = validate_sampling_rate(parsed, "TUSK_SAMPLING_RATE env var") + base_rate = validated + else: + env_rate = os.environ.get("TUSK_SAMPLING_RATE") + if env_rate is not None: + try: + parsed = float(env_rate) + validated = validate_sampling_rate(parsed, "TUSK_SAMPLING_RATE env var") + if validated is not None: + logger.debug(f"Using sampling rate from env var: {validated}") + base_rate = validated + except ValueError: + logger.warning(f"Invalid TUSK_SAMPLING_RATE env var: {env_rate}") + elif config_sampling and config_sampling.base_rate is not None: + validated = validate_sampling_rate( + config_sampling.base_rate, "config file recording.sampling.base_rate" + ) if validated is not None: - logger.debug(f"Using sampling rate from env var: {validated}") - return validated - except ValueError: - logger.warning(f"Invalid TUSK_SAMPLING_RATE env var: {env_rate}") - - # 3. 
Config file - if self.file_config and self.file_config.recording and self.file_config.recording.sampling_rate is not None: - config_rate = self.file_config.recording.sampling_rate - validated = validate_sampling_rate(config_rate, "config file") - if validated is not None: - logger.debug(f"Using sampling rate from config file: {validated}") - return validated + base_rate = validated + elif recording_config and recording_config.sampling_rate is not None: + validated = validate_sampling_rate( + recording_config.sampling_rate, "config file recording.sampling_rate" + ) + if validated is not None: + base_rate = validated + else: + logger.debug("Using default sampling rate: 1.0") + + min_rate = 0.0 + if mode == "adaptive": + validated_min_rate = validate_sampling_rate( + config_sampling.min_rate if config_sampling else None, + "config file recording.sampling.min_rate", + ) + min_rate = validated_min_rate if validated_min_rate is not None else 0.001 + min_rate = min(base_rate, min_rate) + + return ResolvedSamplingConfig( + mode=mode, + base_rate=base_rate, + min_rate=min_rate, + ) + + def _determine_sampling_rate(self, init_param: float | None) -> float: + """Backward-compatible helper that returns only the effective base sampling rate.""" + return self._determine_sampling_config(init_param).base_rate + + def _start_adaptive_sampling_control_loop(self) -> None: + if self.mode != TuskDriftMode.RECORD or self._sampling_mode != "adaptive": + return + + self._adaptive_sampling_controller = AdaptiveSamplingController( + ResolvedSamplingConfig( + mode="adaptive", + base_rate=self._sampling_rate, + min_rate=self._min_sampling_rate, + ) + ) + self._effective_memory_limit_bytes = self._detect_effective_memory_limit_bytes() + self._adaptive_sampling_stop_event.clear() + + self._adaptive_sampling_thread = threading.Thread( + target=self._adaptive_sampling_loop, + daemon=True, + name="drift-adaptive-sampling", + ) + self._adaptive_sampling_thread.start() + 
self._update_adaptive_sampling_health() + + def _adaptive_sampling_loop(self) -> None: + while not self._adaptive_sampling_stop_event.wait(timeout=2.0): + self._update_adaptive_sampling_health() + + def _update_adaptive_sampling_health(self) -> None: + if self._adaptive_sampling_controller is None: + return + + batch_processor = self._td_span_processor._batch_processor if self._td_span_processor else None + queue_fill_ratio = None + dropped_span_count = 0 + if batch_processor is not None and batch_processor.max_queue_size > 0: + queue_fill_ratio = batch_processor.queue_size / batch_processor.max_queue_size + dropped_span_count = batch_processor.dropped_span_count + + export_failure_count = 0 + export_circuit_open = False + if self.span_exporter is not None: + for adapter in self.span_exporter.get_adapters(): + spans_failed = getattr(adapter, "spans_failed", 0) + export_failure_count += int(spans_failed) + export_circuit_open = export_circuit_open or getattr(adapter, "circuit_state", "") == "open" + + self._adaptive_sampling_controller.update( + AdaptiveSamplingHealthSnapshot( + queue_fill_ratio=queue_fill_ratio, + dropped_span_count=dropped_span_count, + export_failure_count=export_failure_count, + export_circuit_open=export_circuit_open, + memory_pressure_ratio=self._get_memory_pressure_ratio(), + ) + ) - # 4. 
Default - logger.debug("Using default sampling rate: 1.0") - return 1.0 + def _detect_effective_memory_limit_bytes(self) -> int | None: + candidates = ( + "/sys/fs/cgroup/memory.max", + "/sys/fs/cgroup/memory/memory.limit_in_bytes", + ) + for path in candidates: + parsed = self._read_numeric_control_file(path) + if parsed is None: + continue + if parsed <= 0 or parsed > 1_000_000_000_000_000: + continue + return parsed + return None + + def _get_memory_pressure_ratio(self) -> float | None: + if self._effective_memory_limit_bytes is None or self._effective_memory_limit_bytes <= 0: + return None + + cgroup_current = self._read_numeric_control_file("/sys/fs/cgroup/memory.current") + if cgroup_current is not None: + return cgroup_current / self._effective_memory_limit_bytes + + cgroup_v1_current = self._read_numeric_control_file("/sys/fs/cgroup/memory/memory.usage_in_bytes") + if cgroup_v1_current is not None: + return cgroup_v1_current / self._effective_memory_limit_bytes + + try: + import resource + + rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + rss_bytes = rss if platform.system() == "Darwin" else rss * 1024 + return rss_bytes / self._effective_memory_limit_bytes + except Exception: + return None + + def _read_numeric_control_file(self, path: str) -> int | None: + try: + if not os.path.exists(path): + return None + raw_value = Path(path).read_text().strip() + if not raw_value or raw_value == "max": + return None + return int(raw_value) + except Exception: + return None def _detect_mode(self) -> TuskDriftMode: """Detect the SDK mode from environment variable.""" @@ -801,6 +954,34 @@ async def send_unpatched_dependency_alert( except Exception as e: logger.debug(f"Failed to send unpatched dependency alert: {e}") + def should_record_root_request(self, *, is_pre_app_start: bool) -> RootSamplingDecision: + if self._adaptive_sampling_controller is not None: + return self._adaptive_sampling_controller.get_decision(is_pre_app_start=is_pre_app_start) + + if 
is_pre_app_start: + return RootSamplingDecision( + should_record=True, + reason="pre_app_start", + mode="fixed", + state="fixed", + base_rate=self._sampling_rate, + min_rate=self._min_sampling_rate, + effective_rate=1.0, + admission_multiplier=1.0, + ) + + should_record = should_sample(self._sampling_rate, True) + return RootSamplingDecision( + should_record=should_record, + reason="sampled" if should_record else "not_sampled", + mode="fixed", + state="fixed", + base_rate=self._sampling_rate, + min_rate=self._min_sampling_rate, + effective_rate=self._sampling_rate, + admission_multiplier=1.0, + ) + def get_sampling_rate(self) -> float: """Get the current sampling rate.""" return self._sampling_rate @@ -835,6 +1016,11 @@ def shutdown(self) -> None: from .coverage_server import stop_coverage_collection + self._adaptive_sampling_stop_event.set() + if self._adaptive_sampling_thread is not None: + self._adaptive_sampling_thread.join(timeout=5.0) + self._adaptive_sampling_thread = None + # Shutdown OpenTelemetry tracer provider if self._td_span_processor is not None: self._td_span_processor.shutdown() diff --git a/drift/core/mode_utils.py b/drift/core/mode_utils.py index 69d9d3a..0a4bc40 100644 --- a/drift/core/mode_utils.py +++ b/drift/core/mode_utils.py @@ -187,11 +187,10 @@ def should_record_inbound_http_request( if not is_pre_app_start: from .drift_sdk import TuskDrift - from .sampling import should_sample sdk = TuskDrift.get_instance() - sampling_rate = sdk.get_sampling_rate() - if not should_sample(sampling_rate, is_app_ready=True): + decision = sdk.should_record_root_request(is_pre_app_start=is_pre_app_start) + if not decision.should_record: return False, "not_sampled" return True, None diff --git a/drift/core/no_recording.py b/drift/core/no_recording.py new file mode 100644 index 0000000..33dc256 --- /dev/null +++ b/drift/core/no_recording.py @@ -0,0 +1,22 @@ +"""Context helpers for suppressing child span creation.""" + +from __future__ import annotations + +from 
collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar + +_recording_suppressed: ContextVar[bool] = ContextVar("td_recording_suppressed", default=False) + + +def is_recording_suppressed() -> bool: + return _recording_suppressed.get() + + +@contextmanager +def suppress_recording() -> Iterator[None]: + token = _recording_suppressed.set(True) + try: + yield + finally: + _recording_suppressed.reset(token) diff --git a/drift/core/tracing/span_utils.py b/drift/core/tracing/span_utils.py index e246833..bc41d7e 100644 --- a/drift/core/tracing/span_utils.py +++ b/drift/core/tracing/span_utils.py @@ -19,6 +19,7 @@ from opentelemetry.trace import SpanKind as OTelSpanKind from opentelemetry.trace import Status, StatusCode +from ..no_recording import is_recording_suppressed from ..types import TuskDriftMode from .td_attributes import TdSpanAttributes @@ -135,6 +136,10 @@ def create_span(options: CreateSpanOptions) -> SpanInfo | None: Returns None if span creation fails. 
""" try: + if is_recording_suppressed(): + logger.debug(f"[SpanUtils] Skipping span creation for '{options.name}' - recording suppressed") + return None + # Import here to avoid circular dependency from ..drift_sdk import TuskDrift diff --git a/drift/instrumentation/django/middleware.py b/drift/instrumentation/django/middleware.py index 670679d..6514855 100644 --- a/drift/instrumentation/django/middleware.py +++ b/drift/instrumentation/django/middleware.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from django.http import HttpRequest, HttpResponse from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request +from ...core.no_recording import suppress_recording from ...core.tracing import TdSpanAttributes from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils from ...core.types import ( @@ -190,7 +191,8 @@ def _record_request(self, request: HttpRequest, sdk, is_pre_app_start: bool) -> ) if not should_record: logger.debug(f"[Django] Skipping request ({skip_reason}), path={path}") - return self.get_response(request) + with suppress_recording(): + return self.get_response(request) start_time_ns = time.time_ns() span_name = f"{method} {path}" diff --git a/drift/instrumentation/fastapi/instrumentation.py b/drift/instrumentation/fastapi/instrumentation.py index d280544..5e1df77 100644 --- a/drift/instrumentation/fastapi/instrumentation.py +++ b/drift/instrumentation/fastapi/instrumentation.py @@ -27,6 +27,7 @@ from ...core.drift_sdk import TuskDrift from ...core.json_schema_helper import JsonSchemaHelper, SchemaMerge from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request +from ...core.no_recording import suppress_recording from ...core.tracing import TdSpanAttributes from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils from ...core.types import ( @@ -267,7 +268,8 @@ async def _record_request( ) if not should_record: logger.debug(f"[FastAPI] Skipping request ({skip_reason}), 
path={raw_path}") - return await original_call(app, scope, receive, send) + with suppress_recording(): + return await original_call(app, scope, receive, send) start_time_ns = time.time_ns() diff --git a/drift/instrumentation/wsgi/handler.py b/drift/instrumentation/wsgi/handler.py index e5f4aed..a85b0a3 100644 --- a/drift/instrumentation/wsgi/handler.py +++ b/drift/instrumentation/wsgi/handler.py @@ -31,6 +31,7 @@ from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request +from ...core.no_recording import suppress_recording from ...core.tracing import TdSpanAttributes from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils from ...core.types import ( @@ -225,7 +226,8 @@ def _create_and_handle_request( ) if not should_record: logger.debug(f"[WSGI] Skipping request ({skip_reason}), path={path}") - return original_wsgi_app(app, environ, start_response) + with suppress_recording(): + return original_wsgi_app(app, environ, start_response) # Capture request body request_body = capture_request_body(environ) diff --git a/scripts/plot_sampling_benchmarks.py b/scripts/plot_sampling_benchmarks.py new file mode 100755 index 0000000..5380f83 --- /dev/null +++ b/scripts/plot_sampling_benchmarks.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +"""Parse stack sampling benchmark logs and generate analysis artifacts.""" + +from __future__ import annotations + +import argparse +import csv +import math +import re +import statistics +from collections import defaultdict +from pathlib import Path + +BENCHMARK_LINE_RE = re.compile( + r"(Benchmark_\S+)\s+\d+\s+\d+\s+ns/op\s+([\d.]+)\s+ops/s(\s+\(~\))?" +) +LOG_NAME_RE = re.compile(r"(?P<stack>.+)_rate-(?P<rate>[0-9.]+)_run-(?P<run>\d+)\.log$") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Build sampling-rate degradation datasets and plots from benchmark logs." 
+ ) + parser.add_argument("--logs-dir", required=True, help="Directory containing *_rate-*_run-*.log files") + parser.add_argument("--output-dir", required=True, help="Directory to write CSV/plot artifacts") + return parser.parse_args() + + +def percentile(values: list[float], p: float) -> float: + if not values: + raise ValueError("percentile() requires non-empty values") + if len(values) == 1: + return values[0] + sorted_vals = sorted(values) + rank = (len(sorted_vals) - 1) * (p / 100.0) + lo = int(math.floor(rank)) + hi = int(math.ceil(rank)) + if lo == hi: + return sorted_vals[lo] + frac = rank - lo + return sorted_vals[lo] * (1.0 - frac) + sorted_vals[hi] * frac + + +def parse_log_file(path: Path) -> list[dict[str, object]]: + match = LOG_NAME_RE.match(path.name) + if not match: + return [] + + stack = match.group("stack") + rate = float(match.group("rate")) + run = int(match.group("run")) + + baseline: dict[str, tuple[float, bool]] = {} + sdk: dict[str, tuple[float, bool]] = {} + section: str | None = None + + with path.open("r", encoding="utf-8") as f: + for raw_line in f: + line = raw_line.strip() + + if "BASELINE (SDK DISABLED)" in line: + section = "baseline" + continue + if "WITH SDK (TUSK_DRIFT_MODE=RECORD)" in line: + section = "sdk" + continue + + m = BENCHMARK_LINE_RE.match(line) + if not m or section is None: + continue + + benchmark_name = m.group(1) + ops = float(m.group(2)) + reliable = m.group(3) is None + entry = (ops, reliable) + + if section == "baseline": + baseline[benchmark_name] = entry + elif section == "sdk": + sdk[benchmark_name] = entry + + rows: list[dict[str, object]] = [] + all_benchmarks = sorted(set(baseline.keys()) | set(sdk.keys())) + for benchmark in all_benchmarks: + base_entry = baseline.get(benchmark) + sdk_entry = sdk.get(benchmark) + + base_ops = base_entry[0] if base_entry else None + sdk_ops = sdk_entry[0] if sdk_entry else None + reliable = bool(base_entry and sdk_entry and base_entry[1] and sdk_entry[1]) + + 
degradation = None + if base_ops is not None and sdk_ops is not None and base_ops > 0: + # Positive means slower with SDK. + degradation = ((base_ops - sdk_ops) / base_ops) * 100.0 + + rows.append( + { + "stack": stack, + "sampling_rate": rate, + "run": run, + "benchmark": benchmark, + "baseline_ops": base_ops, + "sdk_ops": sdk_ops, + "degradation_pct": degradation, + "reliable": reliable, + } + ) + + return rows + + +def write_csv(path: Path, rows: list[dict[str, object]], fieldnames: list[str]) -> None: + with path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow(row) + + +def build_per_benchmark_summary(raw_rows: list[dict[str, object]]) -> list[dict[str, object]]: + grouped: dict[tuple[str, float, str], list[float]] = defaultdict(list) + for row in raw_rows: + degradation = row["degradation_pct"] + if row["reliable"] and isinstance(degradation, float): + key = (str(row["stack"]), float(row["sampling_rate"]), str(row["benchmark"])) + grouped[key].append(degradation) + + summary_rows: list[dict[str, object]] = [] + for (stack, rate, benchmark), values in grouped.items(): + summary_rows.append( + { + "stack": stack, + "sampling_rate": rate, + "benchmark": benchmark, + "samples": len(values), + "median_degradation_pct": statistics.median(values), + "p95_degradation_pct": percentile(values, 95.0), + "min_degradation_pct": min(values), + "max_degradation_pct": max(values), + } + ) + + summary_rows.sort(key=lambda r: (str(r["stack"]), float(r["sampling_rate"]), str(r["benchmark"]))) + return summary_rows + + +def build_stack_summary(raw_rows: list[dict[str, object]]) -> list[dict[str, object]]: + grouped: dict[tuple[str, float], list[float]] = defaultdict(list) + for row in raw_rows: + degradation = row["degradation_pct"] + if row["reliable"] and isinstance(degradation, float): + key = (str(row["stack"]), float(row["sampling_rate"])) + 
grouped[key].append(degradation) + + summary_rows: list[dict[str, object]] = [] + for (stack, rate), values in grouped.items(): + summary_rows.append( + { + "stack": stack, + "sampling_rate": rate, + "samples": len(values), + "median_degradation_pct": statistics.median(values), + "p95_degradation_pct": percentile(values, 95.0), + "min_degradation_pct": min(values), + "max_degradation_pct": max(values), + } + ) + + summary_rows.sort(key=lambda r: (str(r["stack"]), float(r["sampling_rate"]))) + return summary_rows + + +def build_recommendations(stack_summary_rows: list[dict[str, object]]) -> str: + tolerances = [1.0, 3.0, 5.0, 10.0] + by_stack: dict[str, list[dict[str, object]]] = defaultdict(list) + for row in stack_summary_rows: + by_stack[str(row["stack"])].append(row) + + lines = [ + "# Sampling Recommendations", + "", + "Recommended rate uses the highest sampling rate where `p95_degradation_pct <= tolerance`.", + "", + ] + + for stack in sorted(by_stack.keys()): + rows = sorted(by_stack[stack], key=lambda r: float(r["sampling_rate"])) + lines.append(f"## {stack}") + for tol in tolerances: + eligible = [r for r in rows if float(r["p95_degradation_pct"]) <= tol] + if eligible: + best = max(eligible, key=lambda r: float(r["sampling_rate"])) + lines.append( + f"- tolerance <= {tol:.0f}%: sampling_rate <= {best['sampling_rate']} " + f"(p95={float(best['p95_degradation_pct']):.2f}%, samples={best['samples']})" + ) + else: + lines.append(f"- tolerance <= {tol:.0f}%: no measured rate satisfies this bound") + lines.append("") + + return "\n".join(lines).rstrip() + "\n" + + +def maybe_render_plot(stack_summary_rows: list[dict[str, object]], output_dir: Path) -> Path | None: + try: + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed; skipping PNG plot generation") + return None + + by_stack: dict[str, list[dict[str, object]]] = defaultdict(list) + for row in stack_summary_rows: + by_stack[str(row["stack"])].append(row) + + if not 
by_stack: + return None + + fig, axes = plt.subplots(nrows=len(by_stack), ncols=1, figsize=(10, 4 * len(by_stack)), sharex=True) + if len(by_stack) == 1: + axes = [axes] + + for ax, stack in zip(axes, sorted(by_stack.keys())): + rows = sorted(by_stack[stack], key=lambda r: float(r["sampling_rate"])) + rates = [float(r["sampling_rate"]) for r in rows] + med = [float(r["median_degradation_pct"]) for r in rows] + p95 = [float(r["p95_degradation_pct"]) for r in rows] + + ax.plot(rates, med, marker="o", label="median degradation (%)") + ax.plot(rates, p95, marker="x", linestyle="--", label="p95 degradation (%)") + ax.axhline(1.0, color="gray", linestyle=":", linewidth=1) + ax.axhline(3.0, color="gray", linestyle=":", linewidth=1) + ax.axhline(5.0, color="gray", linestyle=":", linewidth=1) + ax.set_title(stack) + ax.set_ylabel("Degradation (%)") + ax.grid(True, alpha=0.25) + ax.legend(loc="best") + + axes[-1].set_xlabel("Sampling rate") + fig.suptitle("Sampling Rate vs Performance Degradation (stack summary)") + fig.tight_layout() + + plot_path = output_dir / "sampling-rate-vs-degradation.png" + fig.savefig(plot_path, dpi=150) + print(f"wrote plot: {plot_path}") + return plot_path + + +def main() -> int: + args = parse_args() + logs_dir = Path(args.logs_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + log_files = sorted(logs_dir.glob("*_rate-*_run-*.log")) + if not log_files: + raise SystemExit(f"No log files found in {logs_dir}") + + raw_rows: list[dict[str, object]] = [] + for log_file in log_files: + raw_rows.extend(parse_log_file(log_file)) + + if not raw_rows: + raise SystemExit("No benchmark rows parsed. 
Check logs for expected benchmark output.") + + raw_rows.sort( + key=lambda r: ( + str(r["stack"]), + float(r["sampling_rate"]), + int(r["run"]), + str(r["benchmark"]), + ) + ) + + per_benchmark_summary = build_per_benchmark_summary(raw_rows) + stack_summary = build_stack_summary(raw_rows) + + raw_csv = output_dir / "sampling_benchmark_raw.csv" + per_bench_csv = output_dir / "sampling_benchmark_per_benchmark_summary.csv" + stack_csv = output_dir / "sampling_benchmark_stack_summary.csv" + rec_md = output_dir / "sampling_recommendations.md" + + write_csv( + raw_csv, + raw_rows, + [ + "stack", + "sampling_rate", + "run", + "benchmark", + "baseline_ops", + "sdk_ops", + "degradation_pct", + "reliable", + ], + ) + write_csv( + per_bench_csv, + per_benchmark_summary, + [ + "stack", + "sampling_rate", + "benchmark", + "samples", + "median_degradation_pct", + "p95_degradation_pct", + "min_degradation_pct", + "max_degradation_pct", + ], + ) + write_csv( + stack_csv, + stack_summary, + [ + "stack", + "sampling_rate", + "samples", + "median_degradation_pct", + "p95_degradation_pct", + "min_degradation_pct", + "max_degradation_pct", + ], + ) + + rec_md.write_text(build_recommendations(stack_summary), encoding="utf-8") + + maybe_render_plot(stack_summary, output_dir) + + print(f"wrote csv: {raw_csv}") + print(f"wrote csv: {per_bench_csv}") + print(f"wrote csv: {stack_csv}") + print(f"wrote markdown: {rec_md}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) + diff --git a/tests/unit/test_adaptive_sampling.py b/tests/unit/test_adaptive_sampling.py new file mode 100644 index 0000000..807473b --- /dev/null +++ b/tests/unit/test_adaptive_sampling.py @@ -0,0 +1,42 @@ +from drift.core.adaptive_sampling import ( + AdaptiveSamplingController, + AdaptiveSamplingHealthSnapshot, + ResolvedSamplingConfig, +) + + +def test_pre_app_start_always_records(): + controller = AdaptiveSamplingController( + ResolvedSamplingConfig(mode="adaptive", base_rate=0.0, min_rate=0.0), + 
random_fn=lambda: 0.99, + now_fn=lambda: 0.0, + ) + + decision = controller.get_decision(is_pre_app_start=True) + + assert decision.should_record is True + assert decision.reason == "pre_app_start" + assert decision.effective_rate == 1.0 + + +def test_controller_load_sheds_and_pauses_on_drops(): + now = {"value": 0.0} + controller = AdaptiveSamplingController( + ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1), + random_fn=lambda: 0.3, + now_fn=lambda: now["value"], + ) + + controller.update(AdaptiveSamplingHealthSnapshot(queue_fill_ratio=0.9)) + load_shed_decision = controller.get_decision(is_pre_app_start=False) + assert load_shed_decision.state == "hot" + assert load_shed_decision.effective_rate < 0.5 + assert load_shed_decision.should_record is False + assert load_shed_decision.reason == "load_shed" + + now["value"] = 1.0 + controller.update(AdaptiveSamplingHealthSnapshot(queue_fill_ratio=0.1, dropped_span_count=1)) + paused_decision = controller.get_decision(is_pre_app_start=False) + assert paused_decision.state == "critical_pause" + assert paused_decision.should_record is False + assert paused_decision.reason == "critical_pause" diff --git a/tests/unit/test_config_loading.py b/tests/unit/test_config_loading.py index 9f1570f..8d6d1b5 100644 --- a/tests/unit/test_config_loading.py +++ b/tests/unit/test_config_loading.py @@ -214,6 +214,38 @@ def test_handles_partial_config(self): finally: os.chdir(original_cwd) + def test_loads_nested_sampling_config(self): + """Should load recording.sampling config alongside legacy fields.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + (project_root / "pyproject.toml").touch() + + tusk_dir = project_root / ".tusk" + tusk_dir.mkdir() + (tusk_dir / "config.yaml").write_text( + """ +recording: + sampling: + mode: adaptive + base_rate: 0.25 + min_rate: 0.05 +""" + ) + + original_cwd = os.getcwd() + try: + os.chdir(project_root) + config = load_tusk_config() + + assert config 
is not None + assert config.recording is not None + assert config.recording.sampling is not None + assert config.recording.sampling.mode == "adaptive" + assert config.recording.sampling.base_rate == 0.25 + assert config.recording.sampling.min_rate == 0.05 + finally: + os.chdir(original_cwd) + def test_handles_invalid_yaml(self): """Should return None when YAML is invalid.""" with tempfile.TemporaryDirectory() as tmpdir: diff --git a/tests/unit/test_mode_utils.py b/tests/unit/test_mode_utils.py index c60cb0f..be94c89 100644 --- a/tests/unit/test_mode_utils.py +++ b/tests/unit/test_mode_utils.py @@ -311,11 +311,9 @@ def test_returns_true_when_no_drop_and_sampled(self, mocker): """Should return (True, None) when not dropped and sampled.""" mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift") mock_sdk = mocker.MagicMock() - mock_sdk.get_sampling_rate.return_value = 1.0 + mock_sdk.should_record_root_request.return_value.should_record = True mock_drift.get_instance.return_value = mock_sdk - mocker.patch("drift.core.sampling.should_sample", return_value=True) - result, reason = should_record_inbound_http_request( method="GET", target="/api/users", @@ -361,11 +359,9 @@ def test_returns_false_when_not_sampled(self, mocker): """Should return (False, 'not_sampled') when sampling decides to skip.""" mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift") mock_sdk = mocker.MagicMock() - mock_sdk.get_sampling_rate.return_value = 0.0 + mock_sdk.should_record_root_request.return_value.should_record = False mock_drift.get_instance.return_value = mock_sdk - mocker.patch("drift.core.sampling.should_sample", return_value=False) - result, reason = should_record_inbound_http_request( method="GET", target="/api/users", @@ -382,7 +378,9 @@ def test_drop_check_happens_before_sampling(self, mocker): mock_transform = mocker.MagicMock() mock_transform.should_drop_inbound_request.return_value = True - mock_sample = mocker.patch("drift.core.sampling.should_sample") + mock_drift = 
mocker.patch("drift.core.drift_sdk.TuskDrift") + mock_sdk = mocker.MagicMock() + mock_drift.get_instance.return_value = mock_sdk result, reason = should_record_inbound_http_request( method="GET", @@ -392,7 +390,6 @@ def test_drop_check_happens_before_sampling(self, mocker): is_pre_app_start=False, ) - # should_sample should never be called if dropped - mock_sample.assert_not_called() + mock_sdk.should_record_root_request.assert_not_called() assert result is False assert reason == "dropped" diff --git a/tests/unit/test_span_utils.py b/tests/unit/test_span_utils.py index c23d0b2..1e1bf36 100644 --- a/tests/unit/test_span_utils.py +++ b/tests/unit/test_span_utils.py @@ -7,6 +7,7 @@ from opentelemetry.trace import SpanKind as OTelSpanKind from opentelemetry.trace import Status, StatusCode +from drift.core.no_recording import suppress_recording from drift.core.tracing.span_utils import ( AddSpanAttributesOptions, CreateSpanOptions, @@ -205,6 +206,23 @@ def test_returns_none_on_exception(self, mocker): assert result is None + def test_returns_none_when_recording_is_suppressed(self, mocker): + """Should not create spans when no-record context is active.""" + mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift") + mock_sdk = mocker.MagicMock() + mock_sdk.get_tracer.return_value = mocker.MagicMock() + mock_drift.get_instance.return_value = mock_sdk + + options = CreateSpanOptions( + name="test-span", + kind=OTelSpanKind.SERVER, + ) + + with suppress_recording(): + result = SpanUtils.create_span(options) + + assert result is None + class TestSpanUtilsWithSpan: """Tests for SpanUtils.with_span context manager.""" From 976a86e3bce991c9a3f4ce1b98d479031b1f2aab Mon Sep 17 00:00:00 2001 From: JY Tan Date: Fri, 10 Apr 2026 00:18:19 -0700 Subject: [PATCH 2/8] Fixes --- drift/core/adaptive_sampling.py | 231 +++++++++++++------------- drift/core/drift_sdk.py | 91 +++++++--- drift/instrumentation/wsgi/handler.py | 30 +++- tests/unit/test_adaptive_sampling.py | 61 +++++++ 
tests/unit/test_drift_sdk.py | 139 ++++++++++++++++ tests/unit/test_wsgi_handler.py | 79 +++++++++ 6 files changed, 494 insertions(+), 137 deletions(-) create mode 100644 tests/unit/test_wsgi_handler.py diff --git a/drift/core/adaptive_sampling.py b/drift/core/adaptive_sampling.py index c287ef2..2adc29e 100644 --- a/drift/core/adaptive_sampling.py +++ b/drift/core/adaptive_sampling.py @@ -5,6 +5,7 @@ import logging import math import random +import threading import time from dataclasses import dataclass from typing import Literal @@ -75,6 +76,7 @@ def __init__( self._config = config self._random_fn = random_fn self._now_fn = now_fn + self._lock = threading.RLock() self._admission_multiplier = 1.0 self._state: AdaptiveSamplingState = "fixed" if config.mode == "fixed" else "healthy" @@ -90,135 +92,140 @@ def __init__( self._recent_failure_signal = 0.0 def update(self, snapshot: AdaptiveSamplingHealthSnapshot) -> None: - if self._config.mode != "adaptive": - self._state = "fixed" - self._admission_multiplier = 1.0 - return - - now_s = self._now_fn() - elapsed_s = 2.0 if self._last_updated_at_s == 0 else max(0.001, now_s - self._last_updated_at_s) - self._last_updated_at_s = now_s - - decay = math.exp(-(elapsed_s * 1000.0) / 30000.0) - self._recent_drop_signal *= decay - self._recent_failure_signal *= decay - - dropped_delta = max(0, snapshot.dropped_span_count - self._prev_dropped_span_count) - export_failure_delta = max(0, snapshot.export_failure_count - self._prev_export_failure_count) - - self._prev_dropped_span_count = snapshot.dropped_span_count - self._prev_export_failure_count = snapshot.export_failure_count - - self._recent_drop_signal += dropped_delta - self._recent_failure_signal += export_failure_delta - - if snapshot.queue_fill_ratio is not None: - queue_fill_ratio = _clamp01(snapshot.queue_fill_ratio) - self._queue_fill_ewma = ( - queue_fill_ratio - if self._queue_fill_ewma is None - else (0.25 * queue_fill_ratio) + (0.75 * self._queue_fill_ewma) + with 
self._lock: + if self._config.mode != "adaptive": + self._state = "fixed" + self._admission_multiplier = 1.0 + return + + now_s = self._now_fn() + elapsed_s = 2.0 if self._last_updated_at_s == 0 else max(0.001, now_s - self._last_updated_at_s) + self._last_updated_at_s = now_s + + decay = math.exp(-(elapsed_s * 1000.0) / 30000.0) + self._recent_drop_signal *= decay + self._recent_failure_signal *= decay + + dropped_delta = max(0, snapshot.dropped_span_count - self._prev_dropped_span_count) + export_failure_delta = max(0, snapshot.export_failure_count - self._prev_export_failure_count) + + self._prev_dropped_span_count = snapshot.dropped_span_count + self._prev_export_failure_count = snapshot.export_failure_count + + self._recent_drop_signal += dropped_delta + self._recent_failure_signal += export_failure_delta + + if snapshot.queue_fill_ratio is not None: + queue_fill_ratio = _clamp01(snapshot.queue_fill_ratio) + self._queue_fill_ewma = ( + queue_fill_ratio + if self._queue_fill_ewma is None + else (0.25 * queue_fill_ratio) + (0.75 * self._queue_fill_ewma) + ) + + queue_pressure = _normalize_between(self._queue_fill_ewma, 0.20, 0.85) + memory_pressure = _normalize_between(snapshot.memory_pressure_ratio, 0.80, 0.92) + export_failure_pressure = _clamp01(self._recent_failure_signal / 5.0) + pressure = max(queue_pressure, memory_pressure, export_failure_pressure) + + hard_brake = ( + dropped_delta > 0 or snapshot.export_circuit_open or (snapshot.memory_pressure_ratio or 0.0) >= 0.92 ) - queue_pressure = _normalize_between(self._queue_fill_ewma, 0.20, 0.85) - memory_pressure = _normalize_between(snapshot.memory_pressure_ratio, 0.80, 0.92) - export_failure_pressure = _clamp01(self._recent_failure_signal / 5.0) - pressure = max(queue_pressure, memory_pressure, export_failure_pressure) + previous_state = self._state + previous_multiplier = self._admission_multiplier + + if hard_brake: + self._paused_until_s = now_s + 15.0 + self._admission_multiplier = 0.0 + self._state = 
"critical_pause" + self._last_decrease_at_s = now_s + self._log_transition(previous_state, previous_multiplier, pressure, snapshot) + return + + if now_s < self._paused_until_s: + self._state = "critical_pause" + self._log_transition(previous_state, previous_multiplier, pressure, snapshot) + return + + min_multiplier = self._get_min_multiplier() + if pressure >= 0.70: + self._admission_multiplier = max(min_multiplier, self._admission_multiplier * 0.4) + self._state = "hot" + self._last_decrease_at_s = now_s + elif pressure >= 0.45: + self._admission_multiplier = max(min_multiplier, self._admission_multiplier * 0.7) + self._state = "warm" + self._last_decrease_at_s = now_s + else: + if pressure <= 0.20 and (now_s - self._last_decrease_at_s) >= 10.0: + self._admission_multiplier = min(1.0, self._admission_multiplier + 0.05) + self._state = "healthy" - hard_brake = ( - dropped_delta > 0 or snapshot.export_circuit_open or (snapshot.memory_pressure_ratio or 0.0) >= 0.92 - ) - - previous_state = self._state - previous_multiplier = self._admission_multiplier - - if hard_brake: - self._paused_until_s = now_s + 15.0 - self._admission_multiplier = 0.0 - self._state = "critical_pause" - self._last_decrease_at_s = now_s self._log_transition(previous_state, previous_multiplier, pressure, snapshot) - return - - if now_s < self._paused_until_s: - self._state = "critical_pause" - self._log_transition(previous_state, previous_multiplier, pressure, snapshot) - return - - min_multiplier = self._get_min_multiplier() - if pressure >= 0.70: - self._admission_multiplier = max(min_multiplier, self._admission_multiplier * 0.4) - self._state = "hot" - self._last_decrease_at_s = now_s - elif pressure >= 0.45: - self._admission_multiplier = max(min_multiplier, self._admission_multiplier * 0.7) - self._state = "warm" - self._last_decrease_at_s = now_s - else: - if pressure <= 0.20 and (now_s - self._last_decrease_at_s) >= 10.0: - self._admission_multiplier = min(1.0, self._admission_multiplier 
+ 0.05) - self._state = "healthy" - - self._log_transition(previous_state, previous_multiplier, pressure, snapshot) def get_decision(self, *, is_pre_app_start: bool) -> RootSamplingDecision: - if is_pre_app_start: - return RootSamplingDecision( - should_record=True, - reason="pre_app_start", - mode=self._config.mode, - state=self._state, - base_rate=self._config.base_rate, - min_rate=self._config.min_rate, - effective_rate=1.0, - admission_multiplier=1.0, + with self._lock: + if is_pre_app_start: + return RootSamplingDecision( + should_record=True, + reason="pre_app_start", + mode=self._config.mode, + state=self._state, + base_rate=self._config.base_rate, + min_rate=self._config.min_rate, + effective_rate=1.0, + admission_multiplier=1.0, + ) + + effective_rate = ( + self.get_effective_sampling_rate() + if self._config.mode == "adaptive" + else _clamp01(self._config.base_rate) ) - effective_rate = ( - self.get_effective_sampling_rate() if self._config.mode == "adaptive" else _clamp01(self._config.base_rate) - ) - - if effective_rate <= 0.0: + if effective_rate <= 0.0: + return RootSamplingDecision( + should_record=False, + reason="critical_pause" if self._state == "critical_pause" else "not_sampled", + mode=self._config.mode, + state=self._state, + base_rate=self._config.base_rate, + min_rate=self._config.min_rate, + effective_rate=effective_rate, + admission_multiplier=self._admission_multiplier, + ) + + should_record = self._random_fn() < effective_rate return RootSamplingDecision( - should_record=False, - reason="critical_pause" if self._state == "critical_pause" else "not_sampled", + should_record=should_record, + reason=( + "sampled" + if should_record + else "load_shed" + if self._config.mode == "adaptive" and effective_rate < self._config.base_rate + else "not_sampled" + ), mode=self._config.mode, state=self._state, base_rate=self._config.base_rate, min_rate=self._config.min_rate, effective_rate=effective_rate, - 
admission_multiplier=self._admission_multiplier, + admission_multiplier=self._admission_multiplier if self._config.mode == "adaptive" else 1.0, ) - should_record = self._random_fn() < effective_rate - return RootSamplingDecision( - should_record=should_record, - reason=( - "sampled" - if should_record - else "load_shed" - if self._config.mode == "adaptive" and effective_rate < self._config.base_rate - else "not_sampled" - ), - mode=self._config.mode, - state=self._state, - base_rate=self._config.base_rate, - min_rate=self._config.min_rate, - effective_rate=effective_rate, - admission_multiplier=self._admission_multiplier if self._config.mode == "adaptive" else 1.0, - ) - def get_effective_sampling_rate(self) -> float: - if self._config.mode != "adaptive": - return _clamp01(self._config.base_rate) - if self._state == "critical_pause" and self._now_fn() < self._paused_until_s: - return 0.0 - effective_rate = self._config.base_rate * self._admission_multiplier - return _clamp( - effective_rate, - min(self._config.base_rate, self._config.min_rate), - self._config.base_rate, - ) + with self._lock: + if self._config.mode != "adaptive": + return _clamp01(self._config.base_rate) + if self._state == "critical_pause" and self._now_fn() < self._paused_until_s: + return 0.0 + effective_rate = self._config.base_rate * self._admission_multiplier + return _clamp( + effective_rate, + min(self._config.base_rate, self._config.min_rate), + self._config.base_rate, + ) def _get_min_multiplier(self) -> float: if self._config.base_rate <= 0.0 or self._config.min_rate <= 0.0: diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 936680f..1d2aa61 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -352,13 +352,14 @@ def _determine_sampling_config(self, init_param: float | None) -> ResolvedSampli config_sampling.mode, ) - base_rate = 1.0 + base_rate: float | None = None if init_param is not None: validated = validate_sampling_rate(init_param, "init params") 
if validated is not None: logger.debug(f"Using sampling rate from init params: {validated}") base_rate = validated - else: + + if base_rate is None: env_rate = os.environ.get("TUSK_SAMPLING_RATE") if env_rate is not None: try: @@ -369,20 +370,22 @@ def _determine_sampling_config(self, init_param: float | None) -> ResolvedSampli base_rate = validated except ValueError: logger.warning(f"Invalid TUSK_SAMPLING_RATE env var: {env_rate}") - elif config_sampling and config_sampling.base_rate is not None: - validated = validate_sampling_rate( - config_sampling.base_rate, "config file recording.sampling.base_rate" - ) - if validated is not None: - base_rate = validated - elif recording_config and recording_config.sampling_rate is not None: - validated = validate_sampling_rate( - recording_config.sampling_rate, "config file recording.sampling_rate" - ) - if validated is not None: - base_rate = validated - else: - logger.debug("Using default sampling rate: 1.0") + + if base_rate is None and config_sampling and config_sampling.base_rate is not None: + validated = validate_sampling_rate(config_sampling.base_rate, "config file recording.sampling.base_rate") + if validated is not None: + logger.debug(f"Using sampling rate from config file recording.sampling.base_rate: {validated}") + base_rate = validated + + if base_rate is None and recording_config and recording_config.sampling_rate is not None: + validated = validate_sampling_rate(recording_config.sampling_rate, "config file recording.sampling_rate") + if validated is not None: + logger.debug(f"Using sampling rate from config file recording.sampling_rate: {validated}") + base_rate = validated + + if base_rate is None: + logger.debug("Using default sampling rate: 1.0") + base_rate = 1.0 min_rate = 0.0 if mode == "adaptive": @@ -423,11 +426,17 @@ def _start_adaptive_sampling_control_loop(self) -> None: name="drift-adaptive-sampling", ) self._adaptive_sampling_thread.start() - self._update_adaptive_sampling_health() + 
self._safe_update_adaptive_sampling_health() def _adaptive_sampling_loop(self) -> None: while not self._adaptive_sampling_stop_event.wait(timeout=2.0): + self._safe_update_adaptive_sampling_health() + + def _safe_update_adaptive_sampling_health(self) -> None: + try: self._update_adaptive_sampling_health() + except Exception: + logger.error("Adaptive sampling health update failed; keeping previous controller state.", exc_info=True) def _update_adaptive_sampling_health(self) -> None: if self._adaptive_sampling_controller is None: @@ -484,14 +493,52 @@ def _get_memory_pressure_ratio(self) -> float | None: if cgroup_v1_current is not None: return cgroup_v1_current / self._effective_memory_limit_bytes + current_rss_bytes = self._read_current_rss_bytes() + if current_rss_bytes is not None: + return current_rss_bytes / self._effective_memory_limit_bytes + + return None + + @staticmethod + def _parse_proc_status_rss_bytes(raw_status: str) -> int | None: + for line in raw_status.splitlines(): + if not line.startswith("VmRSS:"): + continue + + parts = line.split() + if len(parts) < 3 or parts[2].lower() != "kb": + return None + + return int(parts[1]) * 1024 + + return None + + @staticmethod + def _parse_proc_statm_rss_bytes(raw_statm: str, page_size: int) -> int | None: + fields = raw_statm.split() + if len(fields) < 2: + return None + + return int(fields[1]) * page_size + + def _read_current_rss_bytes(self) -> int | None: try: - import resource + proc_status_path = Path("/proc/self/status") + if proc_status_path.exists(): + parsed = self._parse_proc_status_rss_bytes(proc_status_path.read_text()) + if parsed is not None: + return parsed + except Exception: + pass - rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - rss_bytes = rss if platform.system() == "Darwin" else rss * 1024 - return rss_bytes / self._effective_memory_limit_bytes + try: + proc_statm_path = Path("/proc/self/statm") + if proc_statm_path.exists(): + return 
self._parse_proc_statm_rss_bytes(proc_statm_path.read_text(), int(os.sysconf("SC_PAGE_SIZE"))) except Exception: - return None + pass + + return None def _read_numeric_control_file(self, path: str) -> int | None: try: diff --git a/drift/instrumentation/wsgi/handler.py b/drift/instrumentation/wsgi/handler.py index a85b0a3..231e90f 100644 --- a/drift/instrumentation/wsgi/handler.py +++ b/drift/instrumentation/wsgi/handler.py @@ -9,7 +9,7 @@ import json import logging import time -from collections.abc import Iterable +from collections.abc import Iterable, Iterator from typing import TYPE_CHECKING, Any from opentelemetry import context as otel_context @@ -54,6 +54,30 @@ ) +class SuppressedResponseIterable(Iterable[bytes]): + """Keep no-record suppression active while a skipped WSGI response is consumed.""" + + def __init__(self, response: Iterable[bytes]): + self._response = response + + def __iter__(self) -> Iterator[bytes]: + with suppress_recording(): + iterator = iter(self._response) + while True: + try: + with suppress_recording(): + chunk = next(iterator) + except StopIteration: + return + yield chunk + + def close(self) -> None: + close_method = getattr(self._response, "close", None) + if close_method is not None: + with suppress_recording(): + close_method() + + def handle_wsgi_request( app: WSGIApplication, environ: WSGIEnvironment, @@ -226,8 +250,8 @@ def _create_and_handle_request( ) if not should_record: logger.debug(f"[WSGI] Skipping request ({skip_reason}), path={path}") - with suppress_recording(): - return original_wsgi_app(app, environ, start_response) + response = original_wsgi_app(app, environ, start_response) + return SuppressedResponseIterable(response) # Capture request body request_body = capture_request_body(environ) diff --git a/tests/unit/test_adaptive_sampling.py b/tests/unit/test_adaptive_sampling.py index 807473b..09f0459 100644 --- a/tests/unit/test_adaptive_sampling.py +++ b/tests/unit/test_adaptive_sampling.py @@ -1,3 +1,5 @@ +import 
threading + from drift.core.adaptive_sampling import ( AdaptiveSamplingController, AdaptiveSamplingHealthSnapshot, @@ -40,3 +42,62 @@ def test_controller_load_sheds_and_pauses_on_drops(): assert paused_decision.state == "critical_pause" assert paused_decision.should_record is False assert paused_decision.reason == "critical_pause" + + +def test_get_decision_waits_for_controller_lock(): + controller = AdaptiveSamplingController( + ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1), + random_fn=lambda: 0.0, + now_fn=lambda: 0.0, + ) + started = threading.Event() + finished = threading.Event() + result = {} + + def worker() -> None: + started.set() + result["decision"] = controller.get_decision(is_pre_app_start=False) + finished.set() + + thread = threading.Thread(target=worker) + controller._lock.acquire() + try: + thread.start() + assert started.wait(timeout=1.0) + assert not finished.wait(timeout=0.05) + finally: + controller._lock.release() + + assert finished.wait(timeout=1.0) + thread.join(timeout=1.0) + assert not thread.is_alive() + assert result["decision"].effective_rate == 0.5 + + +def test_update_waits_for_controller_lock(): + controller = AdaptiveSamplingController( + ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1), + random_fn=lambda: 0.0, + now_fn=lambda: 0.0, + ) + started = threading.Event() + finished = threading.Event() + + def worker() -> None: + started.set() + controller.update(AdaptiveSamplingHealthSnapshot(queue_fill_ratio=0.9)) + finished.set() + + thread = threading.Thread(target=worker) + controller._lock.acquire() + try: + thread.start() + assert started.wait(timeout=1.0) + assert not finished.wait(timeout=0.05) + finally: + controller._lock.release() + + assert finished.wait(timeout=1.0) + thread.join(timeout=1.0) + assert not thread.is_alive() + assert controller.get_decision(is_pre_app_start=False).state == "hot" diff --git a/tests/unit/test_drift_sdk.py b/tests/unit/test_drift_sdk.py index 
954ae31..e7b1d3a 100644 --- a/tests/unit/test_drift_sdk.py +++ b/tests/unit/test_drift_sdk.py @@ -6,6 +6,7 @@ import pytest +from drift.core.config import RecordingConfig, SamplingConfig, TuskFileConfig from drift.core.drift_sdk import TuskDrift from drift.core.types import TuskDriftMode @@ -156,6 +157,30 @@ def test_init_param_takes_precedence_over_env_var(self, reset_singleton): assert result == 0.75 + def test_invalid_init_param_falls_back_to_env_var(self, reset_singleton): + """Should use env var when init param is present but invalid.""" + os.environ["TUSK_DRIFT_MODE"] = "DISABLED" + os.environ["TUSK_SAMPLING_RATE"] = "0.25" + instance = TuskDrift.get_instance() + + result = instance._determine_sampling_rate(2.0) + + assert result == 0.25 + + def test_invalid_init_param_falls_back_to_config_file(self, reset_singleton): + """Should use config file when init param is present but invalid.""" + os.environ["TUSK_DRIFT_MODE"] = "DISABLED" + instance = TuskDrift.get_instance() + instance.file_config = TuskFileConfig( + recording=RecordingConfig( + sampling=SamplingConfig(base_rate=0.4), + ) + ) + + result = instance._determine_sampling_rate(2.0) + + assert result == 0.4 + def test_defaults_to_1_0(self, reset_singleton): """Should default to 1.0 (100%) sampling rate.""" os.environ["TUSK_DRIFT_MODE"] = "DISABLED" @@ -175,6 +200,21 @@ def test_rejects_invalid_env_var_sampling_rate(self, reset_singleton): assert result == 1.0 + def test_invalid_env_var_falls_back_to_config_file(self, reset_singleton): + """Should use config file when env var is present but invalid.""" + os.environ["TUSK_DRIFT_MODE"] = "DISABLED" + os.environ["TUSK_SAMPLING_RATE"] = "invalid" + instance = TuskDrift.get_instance() + instance.file_config = TuskFileConfig( + recording=RecordingConfig( + sampling=SamplingConfig(base_rate=0.4), + ) + ) + + result = instance._determine_sampling_rate(None) + + assert result == 0.4 + class TestTuskDriftInitialize: """Tests for TuskDrift.initialize method.""" @@ 
-412,6 +452,105 @@ def test_shutdown_cleans_up_resources(self, reset_singleton, mocker): mock_tracer_provider.shutdown.assert_called_once() +class TestTuskDriftAdaptiveSampling: + """Tests for adaptive sampling health monitoring.""" + + @pytest.fixture(autouse=True) + def reset_singleton(self): + """Reset singleton state before each test.""" + TuskDrift._instance = None + TuskDrift._initialized = False + yield + TuskDrift._instance = None + TuskDrift._initialized = False + + def test_safe_update_logs_and_swallows_health_update_exceptions(self, reset_singleton, mocker): + """Should log and continue when health updates fail.""" + instance = TuskDrift.get_instance() + mocker.patch.object(instance, "_update_adaptive_sampling_health", side_effect=RuntimeError("boom")) + log_error = mocker.patch("drift.core.drift_sdk.logger.error") + + instance._safe_update_adaptive_sampling_health() + + log_error.assert_called_once() + assert "Adaptive sampling health update failed" in log_error.call_args.args[0] + + def test_adaptive_sampling_loop_continues_after_update_exception(self, reset_singleton, mocker): + """Should keep polling after a single health update failure.""" + instance = TuskDrift.get_instance() + stop_event = mocker.MagicMock() + stop_event.wait.side_effect = [False, False, True] + instance._adaptive_sampling_stop_event = stop_event + log_error = mocker.patch("drift.core.drift_sdk.logger.error") + + update_health = mocker.patch.object( + instance, + "_update_adaptive_sampling_health", + side_effect=[RuntimeError("boom"), None], + ) + + instance._adaptive_sampling_loop() + + assert update_health.call_count == 2 + log_error.assert_called_once() + assert "Adaptive sampling health update failed" in log_error.call_args.args[0] + + +class TestTuskDriftMemoryPressure: + """Tests for memory pressure measurement helpers.""" + + @pytest.fixture(autouse=True) + def reset_singleton(self): + """Reset singleton state before each test.""" + TuskDrift._instance = None + 
TuskDrift._initialized = False + yield + TuskDrift._instance = None + TuskDrift._initialized = False + + def test_parse_proc_status_rss_bytes(self, reset_singleton): + """Should parse current RSS from /proc/self/status.""" + raw_status = "Name:\tpython\nVmRSS:\t1234 kB\nThreads:\t8\n" + + assert TuskDrift._parse_proc_status_rss_bytes(raw_status) == 1234 * 1024 + + def test_read_current_rss_bytes_falls_back_to_proc_statm(self, reset_singleton, mocker): + """Should use /proc/self/statm when /proc/self/status is unavailable.""" + instance = TuskDrift.get_instance() + + mocker.patch( + "drift.core.drift_sdk.Path.exists", + autospec=True, + side_effect=lambda path: str(path) == "/proc/self/statm", + ) + mocker.patch( + "drift.core.drift_sdk.Path.read_text", + autospec=True, + side_effect=lambda path: "100 25 0 0 0 0 0\n" if str(path) == "/proc/self/statm" else "", + ) + mocker.patch("drift.core.drift_sdk.os.sysconf", return_value=4096) + + assert instance._read_current_rss_bytes() == 25 * 4096 + + def test_get_memory_pressure_ratio_uses_current_rss_fallback(self, reset_singleton, mocker): + """Should use current RSS fallback when cgroup current usage is unavailable.""" + instance = TuskDrift.get_instance() + instance._effective_memory_limit_bytes = 1024 + mocker.patch.object(instance, "_read_numeric_control_file", return_value=None) + mocker.patch.object(instance, "_read_current_rss_bytes", return_value=256) + + assert instance._get_memory_pressure_ratio() == 0.25 + + def test_get_memory_pressure_ratio_returns_none_without_current_measurement(self, reset_singleton, mocker): + """Should return None when no current memory measurement is available.""" + instance = TuskDrift.get_instance() + instance._effective_memory_limit_bytes = 1024 + mocker.patch.object(instance, "_read_numeric_control_file", return_value=None) + mocker.patch.object(instance, "_read_current_rss_bytes", return_value=None) + + assert instance._get_memory_pressure_ratio() is None + + class 
TestTuskDriftGetTracer: """Tests for TuskDrift.get_tracer method.""" diff --git a/tests/unit/test_wsgi_handler.py b/tests/unit/test_wsgi_handler.py new file mode 100644 index 0000000..f13e65b --- /dev/null +++ b/tests/unit/test_wsgi_handler.py @@ -0,0 +1,79 @@ +"""Tests for WSGI handler request lifecycle behavior.""" + +from __future__ import annotations + +from collections.abc import Iterable, Iterator +from typing import Any + +from drift.instrumentation.wsgi.handler import _create_and_handle_request + + +class StreamingResponse(Iterable[bytes]): + def __init__(self, observed: list[tuple[str, bool]]) -> None: + self._observed = observed + self._yielded = False + + def __iter__(self) -> Iterator[bytes]: + from drift.core.no_recording import is_recording_suppressed + + self._observed.append(("iter", is_recording_suppressed())) + return self + + def __next__(self) -> bytes: + from drift.core.no_recording import is_recording_suppressed + + self._observed.append(("next", is_recording_suppressed())) + if self._yielded: + raise StopIteration + self._yielded = True + return b"chunk" + + def close(self) -> None: + from drift.core.no_recording import is_recording_suppressed + + self._observed.append(("close", is_recording_suppressed())) + + +def test_skipped_wsgi_request_keeps_suppression_during_streaming_iteration_and_close(mocker) -> None: + observed: list[tuple[str, bool]] = [] + response = StreamingResponse(observed) + + mocker.patch( + "drift.instrumentation.wsgi.handler.should_record_inbound_http_request", + return_value=(False, "not_sampled"), + ) + + def original_wsgi_app(_app: Any, _environ: dict[str, Any], _start_response: Any) -> Iterable[bytes]: + return response + + def app(_environ: dict[str, Any], _start_response: Any) -> Iterable[bytes]: + return response + + wrapped_response = _create_and_handle_request( + app=app, + environ={ + "REQUEST_METHOD": "GET", + "PATH_INFO": "/stream", + "QUERY_STRING": "", + }, + start_response=lambda status, headers, 
exc_info=None: None, + original_wsgi_app=original_wsgi_app, + framework_name="wsgi", + instrumentation_name="WsgiInstrumentation", + transform_engine=None, + sdk=object(), + is_pre_app_start=False, + replay_token=None, + ) + + assert list(wrapped_response) == [b"chunk"] + close_method = getattr(wrapped_response, "close", None) + assert close_method is not None + close_method() + + assert observed == [ + ("iter", True), + ("next", True), + ("next", True), + ("close", True), + ] From b5118df821a5dc19fc19c7572b03b92c20ec4b6f Mon Sep 17 00:00:00 2001 From: JY Tan Date: Fri, 10 Apr 2026 13:29:34 -0700 Subject: [PATCH 3/8] More fixes --- drift/core/adaptive_sampling.py | 4 ++-- drift/core/mode_utils.py | 5 +++-- drift/instrumentation/wsgi/handler.py | 3 ++- tests/unit/test_adaptive_sampling.py | 17 +++++++++++++++++ tests/unit/test_mode_utils.py | 20 ++++++++++++++++++++ tests/unit/test_wsgi_handler.py | 4 ++++ 6 files changed, 48 insertions(+), 5 deletions(-) diff --git a/drift/core/adaptive_sampling.py b/drift/core/adaptive_sampling.py index 2adc29e..ec6710b 100644 --- a/drift/core/adaptive_sampling.py +++ b/drift/core/adaptive_sampling.py @@ -81,7 +81,7 @@ def __init__( self._admission_multiplier = 1.0 self._state: AdaptiveSamplingState = "fixed" if config.mode == "fixed" else "healthy" self._paused_until_s = 0.0 - self._last_updated_at_s = 0.0 + self._last_updated_at_s: float | None = None self._last_decrease_at_s = 0.0 self._prev_dropped_span_count = 0 @@ -99,7 +99,7 @@ def update(self, snapshot: AdaptiveSamplingHealthSnapshot) -> None: return now_s = self._now_fn() - elapsed_s = 2.0 if self._last_updated_at_s == 0 else max(0.001, now_s - self._last_updated_at_s) + elapsed_s = 2.0 if self._last_updated_at_s is None else max(0.001, now_s - self._last_updated_at_s) self._last_updated_at_s = now_s decay = math.exp(-(elapsed_s * 1000.0) / 30000.0) diff --git a/drift/core/mode_utils.py b/drift/core/mode_utils.py index 0a4bc40..f671a26 100644 --- 
a/drift/core/mode_utils.py +++ b/drift/core/mode_utils.py @@ -180,7 +180,8 @@ def should_record_inbound_http_request( Returns: Tuple of (should_record, skip_reason): - should_record: True if request should be recorded - - skip_reason: If False, explains why ("dropped" or "not_sampled"), None otherwise + - skip_reason: If False, explains why ("dropped", "not_sampled", + "load_shed", or "critical_pause"), None otherwise """ if transform_engine and transform_engine.should_drop_inbound_request(method, target, headers): return False, "dropped" @@ -191,6 +192,6 @@ def should_record_inbound_http_request( sdk = TuskDrift.get_instance() decision = sdk.should_record_root_request(is_pre_app_start=is_pre_app_start) if not decision.should_record: - return False, "not_sampled" + return False, decision.reason return True, None diff --git a/drift/instrumentation/wsgi/handler.py b/drift/instrumentation/wsgi/handler.py index 231e90f..4428e83 100644 --- a/drift/instrumentation/wsgi/handler.py +++ b/drift/instrumentation/wsgi/handler.py @@ -250,7 +250,8 @@ def _create_and_handle_request( ) if not should_record: logger.debug(f"[WSGI] Skipping request ({skip_reason}), path={path}") - response = original_wsgi_app(app, environ, start_response) + with suppress_recording(): + response = original_wsgi_app(app, environ, start_response) return SuppressedResponseIterable(response) # Capture request body diff --git a/tests/unit/test_adaptive_sampling.py b/tests/unit/test_adaptive_sampling.py index 09f0459..1f4d451 100644 --- a/tests/unit/test_adaptive_sampling.py +++ b/tests/unit/test_adaptive_sampling.py @@ -1,3 +1,4 @@ +import math import threading from drift.core.adaptive_sampling import ( @@ -44,6 +45,22 @@ def test_controller_load_sheds_and_pauses_on_drops(): assert paused_decision.reason == "critical_pause" +def test_elapsed_time_uses_zero_timestamp_as_real_value(): + now = {"value": 0.0} + controller = AdaptiveSamplingController( + ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, 
min_rate=0.1), + random_fn=lambda: 0.0, + now_fn=lambda: now["value"], + ) + + controller.update(AdaptiveSamplingHealthSnapshot(export_failure_count=1)) + now["value"] = 0.5 + controller.update(AdaptiveSamplingHealthSnapshot(export_failure_count=1)) + + expected_decay = math.exp(-(0.5 * 1000.0) / 30000.0) + assert math.isclose(controller._recent_failure_signal, expected_decay, rel_tol=1e-6) + + def test_get_decision_waits_for_controller_lock(): controller = AdaptiveSamplingController( ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1), diff --git a/tests/unit/test_mode_utils.py b/tests/unit/test_mode_utils.py index be94c89..3f9b001 100644 --- a/tests/unit/test_mode_utils.py +++ b/tests/unit/test_mode_utils.py @@ -360,6 +360,7 @@ def test_returns_false_when_not_sampled(self, mocker): mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift") mock_sdk = mocker.MagicMock() mock_sdk.should_record_root_request.return_value.should_record = False + mock_sdk.should_record_root_request.return_value.reason = "not_sampled" mock_drift.get_instance.return_value = mock_sdk result, reason = should_record_inbound_http_request( @@ -373,6 +374,25 @@ def test_returns_false_when_not_sampled(self, mocker): assert result is False assert reason == "not_sampled" + def test_returns_controller_reason_when_adaptive_sampling_skips(self, mocker): + """Should preserve adaptive controller reasons for debug logging.""" + mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift") + mock_sdk = mocker.MagicMock() + mock_sdk.should_record_root_request.return_value.should_record = False + mock_sdk.should_record_root_request.return_value.reason = "critical_pause" + mock_drift.get_instance.return_value = mock_sdk + + result, reason = should_record_inbound_http_request( + method="GET", + target="/api/users", + headers={}, + transform_engine=None, + is_pre_app_start=False, + ) + + assert result is False + assert reason == "critical_pause" + def 
test_drop_check_happens_before_sampling(self, mocker): """Should check drop rules before sampling.""" mock_transform = mocker.MagicMock() diff --git a/tests/unit/test_wsgi_handler.py b/tests/unit/test_wsgi_handler.py index f13e65b..0850852 100644 --- a/tests/unit/test_wsgi_handler.py +++ b/tests/unit/test_wsgi_handler.py @@ -44,6 +44,9 @@ def test_skipped_wsgi_request_keeps_suppression_during_streaming_iteration_and_c ) def original_wsgi_app(_app: Any, _environ: dict[str, Any], _start_response: Any) -> Iterable[bytes]: + from drift.core.no_recording import is_recording_suppressed + + observed.append(("call", is_recording_suppressed())) return response def app(_environ: dict[str, Any], _start_response: Any) -> Iterable[bytes]: @@ -72,6 +75,7 @@ def app(_environ: dict[str, Any], _start_response: Any) -> Iterable[bytes]: close_method() assert observed == [ + ("call", True), ("iter", True), ("next", True), ("next", True), From e4a89e6773f58b2215a501ef751a3810d07d5078 Mon Sep 17 00:00:00 2001 From: JY Tan Date: Fri, 10 Apr 2026 13:36:27 -0700 Subject: [PATCH 4/8] Remove gitignore entry and irrelevant script --- .gitignore | 3 - scripts/plot_sampling_benchmarks.py | 338 ---------------------------- 2 files changed, 341 deletions(-) delete mode 100755 scripts/plot_sampling_benchmarks.py diff --git a/.gitignore b/.gitignore index 297301c..1c1425a 100644 --- a/.gitignore +++ b/.gitignore @@ -220,9 +220,6 @@ __marimo__/ **/.tusk/traces/ **/.tusk/logs/ -# Bug tracking -**/BUG_TRACKING.md - # macOS .DS_Store diff --git a/scripts/plot_sampling_benchmarks.py b/scripts/plot_sampling_benchmarks.py deleted file mode 100755 index 5380f83..0000000 --- a/scripts/plot_sampling_benchmarks.py +++ /dev/null @@ -1,338 +0,0 @@ -#!/usr/bin/env python3 -"""Parse stack sampling benchmark logs and generate analysis artifacts.""" - -from __future__ import annotations - -import argparse -import csv -import math -import re -import statistics -from collections import defaultdict -from pathlib 
import Path - -BENCHMARK_LINE_RE = re.compile( - r"(Benchmark_\S+)\s+\d+\s+\d+\s+ns/op\s+([\d.]+)\s+ops/s(\s+\(~\))?" -) -LOG_NAME_RE = re.compile(r"(?P<stack>.+)_rate-(?P<rate>[0-9.]+)_run-(?P<run>\d+)\.log$") - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Build sampling-rate degradation datasets and plots from benchmark logs." - ) - parser.add_argument("--logs-dir", required=True, help="Directory containing *_rate-*_run-*.log files") - parser.add_argument("--output-dir", required=True, help="Directory to write CSV/plot artifacts") - return parser.parse_args() - - -def percentile(values: list[float], p: float) -> float: - if not values: - raise ValueError("percentile() requires non-empty values") - if len(values) == 1: - return values[0] - sorted_vals = sorted(values) - rank = (len(sorted_vals) - 1) * (p / 100.0) - lo = int(math.floor(rank)) - hi = int(math.ceil(rank)) - if lo == hi: - return sorted_vals[lo] - frac = rank - lo - return sorted_vals[lo] * (1.0 - frac) + sorted_vals[hi] * frac - - -def parse_log_file(path: Path) -> list[dict[str, object]]: - match = LOG_NAME_RE.match(path.name) - if not match: - return [] - - stack = match.group("stack") - rate = float(match.group("rate")) - run = int(match.group("run")) - - baseline: dict[str, tuple[float, bool]] = {} - sdk: dict[str, tuple[float, bool]] = {} - section: str | None = None - - with path.open("r", encoding="utf-8") as f: - for raw_line in f: - line = raw_line.strip() - - if "BASELINE (SDK DISABLED)" in line: - section = "baseline" - continue - if "WITH SDK (TUSK_DRIFT_MODE=RECORD)" in line: - section = "sdk" - continue - - m = BENCHMARK_LINE_RE.match(line) - if not m or section is None: - continue - - benchmark_name = m.group(1) - ops = float(m.group(2)) - reliable = m.group(3) is None - entry = (ops, reliable) - - if section == "baseline": - baseline[benchmark_name] = entry - elif section == "sdk": - sdk[benchmark_name] = entry - - rows: list[dict[str, object]] = [] -
all_benchmarks = sorted(set(baseline.keys()) | set(sdk.keys())) - for benchmark in all_benchmarks: - base_entry = baseline.get(benchmark) - sdk_entry = sdk.get(benchmark) - - base_ops = base_entry[0] if base_entry else None - sdk_ops = sdk_entry[0] if sdk_entry else None - reliable = bool(base_entry and sdk_entry and base_entry[1] and sdk_entry[1]) - - degradation = None - if base_ops is not None and sdk_ops is not None and base_ops > 0: - # Positive means slower with SDK. - degradation = ((base_ops - sdk_ops) / base_ops) * 100.0 - - rows.append( - { - "stack": stack, - "sampling_rate": rate, - "run": run, - "benchmark": benchmark, - "baseline_ops": base_ops, - "sdk_ops": sdk_ops, - "degradation_pct": degradation, - "reliable": reliable, - } - ) - - return rows - - -def write_csv(path: Path, rows: list[dict[str, object]], fieldnames: list[str]) -> None: - with path.open("w", encoding="utf-8", newline="") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - for row in rows: - writer.writerow(row) - - -def build_per_benchmark_summary(raw_rows: list[dict[str, object]]) -> list[dict[str, object]]: - grouped: dict[tuple[str, float, str], list[float]] = defaultdict(list) - for row in raw_rows: - degradation = row["degradation_pct"] - if row["reliable"] and isinstance(degradation, float): - key = (str(row["stack"]), float(row["sampling_rate"]), str(row["benchmark"])) - grouped[key].append(degradation) - - summary_rows: list[dict[str, object]] = [] - for (stack, rate, benchmark), values in grouped.items(): - summary_rows.append( - { - "stack": stack, - "sampling_rate": rate, - "benchmark": benchmark, - "samples": len(values), - "median_degradation_pct": statistics.median(values), - "p95_degradation_pct": percentile(values, 95.0), - "min_degradation_pct": min(values), - "max_degradation_pct": max(values), - } - ) - - summary_rows.sort(key=lambda r: (str(r["stack"]), float(r["sampling_rate"]), str(r["benchmark"]))) - return summary_rows - - -def 
build_stack_summary(raw_rows: list[dict[str, object]]) -> list[dict[str, object]]: - grouped: dict[tuple[str, float], list[float]] = defaultdict(list) - for row in raw_rows: - degradation = row["degradation_pct"] - if row["reliable"] and isinstance(degradation, float): - key = (str(row["stack"]), float(row["sampling_rate"])) - grouped[key].append(degradation) - - summary_rows: list[dict[str, object]] = [] - for (stack, rate), values in grouped.items(): - summary_rows.append( - { - "stack": stack, - "sampling_rate": rate, - "samples": len(values), - "median_degradation_pct": statistics.median(values), - "p95_degradation_pct": percentile(values, 95.0), - "min_degradation_pct": min(values), - "max_degradation_pct": max(values), - } - ) - - summary_rows.sort(key=lambda r: (str(r["stack"]), float(r["sampling_rate"]))) - return summary_rows - - -def build_recommendations(stack_summary_rows: list[dict[str, object]]) -> str: - tolerances = [1.0, 3.0, 5.0, 10.0] - by_stack: dict[str, list[dict[str, object]]] = defaultdict(list) - for row in stack_summary_rows: - by_stack[str(row["stack"])].append(row) - - lines = [ - "# Sampling Recommendations", - "", - "Recommended rate uses the highest sampling rate where `p95_degradation_pct <= tolerance`.", - "", - ] - - for stack in sorted(by_stack.keys()): - rows = sorted(by_stack[stack], key=lambda r: float(r["sampling_rate"])) - lines.append(f"## {stack}") - for tol in tolerances: - eligible = [r for r in rows if float(r["p95_degradation_pct"]) <= tol] - if eligible: - best = max(eligible, key=lambda r: float(r["sampling_rate"])) - lines.append( - f"- tolerance <= {tol:.0f}%: sampling_rate <= {best['sampling_rate']} " - f"(p95={float(best['p95_degradation_pct']):.2f}%, samples={best['samples']})" - ) - else: - lines.append(f"- tolerance <= {tol:.0f}%: no measured rate satisfies this bound") - lines.append("") - - return "\n".join(lines).rstrip() + "\n" - - -def maybe_render_plot(stack_summary_rows: list[dict[str, object]], 
output_dir: Path) -> Path | None: - try: - import matplotlib.pyplot as plt - except ImportError: - print("matplotlib not installed; skipping PNG plot generation") - return None - - by_stack: dict[str, list[dict[str, object]]] = defaultdict(list) - for row in stack_summary_rows: - by_stack[str(row["stack"])].append(row) - - if not by_stack: - return None - - fig, axes = plt.subplots(nrows=len(by_stack), ncols=1, figsize=(10, 4 * len(by_stack)), sharex=True) - if len(by_stack) == 1: - axes = [axes] - - for ax, stack in zip(axes, sorted(by_stack.keys())): - rows = sorted(by_stack[stack], key=lambda r: float(r["sampling_rate"])) - rates = [float(r["sampling_rate"]) for r in rows] - med = [float(r["median_degradation_pct"]) for r in rows] - p95 = [float(r["p95_degradation_pct"]) for r in rows] - - ax.plot(rates, med, marker="o", label="median degradation (%)") - ax.plot(rates, p95, marker="x", linestyle="--", label="p95 degradation (%)") - ax.axhline(1.0, color="gray", linestyle=":", linewidth=1) - ax.axhline(3.0, color="gray", linestyle=":", linewidth=1) - ax.axhline(5.0, color="gray", linestyle=":", linewidth=1) - ax.set_title(stack) - ax.set_ylabel("Degradation (%)") - ax.grid(True, alpha=0.25) - ax.legend(loc="best") - - axes[-1].set_xlabel("Sampling rate") - fig.suptitle("Sampling Rate vs Performance Degradation (stack summary)") - fig.tight_layout() - - plot_path = output_dir / "sampling-rate-vs-degradation.png" - fig.savefig(plot_path, dpi=150) - print(f"wrote plot: {plot_path}") - return plot_path - - -def main() -> int: - args = parse_args() - logs_dir = Path(args.logs_dir) - output_dir = Path(args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - log_files = sorted(logs_dir.glob("*_rate-*_run-*.log")) - if not log_files: - raise SystemExit(f"No log files found in {logs_dir}") - - raw_rows: list[dict[str, object]] = [] - for log_file in log_files: - raw_rows.extend(parse_log_file(log_file)) - - if not raw_rows: - raise SystemExit("No benchmark 
rows parsed. Check logs for expected benchmark output.") - - raw_rows.sort( - key=lambda r: ( - str(r["stack"]), - float(r["sampling_rate"]), - int(r["run"]), - str(r["benchmark"]), - ) - ) - - per_benchmark_summary = build_per_benchmark_summary(raw_rows) - stack_summary = build_stack_summary(raw_rows) - - raw_csv = output_dir / "sampling_benchmark_raw.csv" - per_bench_csv = output_dir / "sampling_benchmark_per_benchmark_summary.csv" - stack_csv = output_dir / "sampling_benchmark_stack_summary.csv" - rec_md = output_dir / "sampling_recommendations.md" - - write_csv( - raw_csv, - raw_rows, - [ - "stack", - "sampling_rate", - "run", - "benchmark", - "baseline_ops", - "sdk_ops", - "degradation_pct", - "reliable", - ], - ) - write_csv( - per_bench_csv, - per_benchmark_summary, - [ - "stack", - "sampling_rate", - "benchmark", - "samples", - "median_degradation_pct", - "p95_degradation_pct", - "min_degradation_pct", - "max_degradation_pct", - ], - ) - write_csv( - stack_csv, - stack_summary, - [ - "stack", - "sampling_rate", - "samples", - "median_degradation_pct", - "p95_degradation_pct", - "min_degradation_pct", - "max_degradation_pct", - ], - ) - - rec_md.write_text(build_recommendations(stack_summary), encoding="utf-8") - - maybe_render_plot(stack_summary, output_dir) - - print(f"wrote csv: {raw_csv}") - print(f"wrote csv: {per_bench_csv}") - print(f"wrote csv: {stack_csv}") - print(f"wrote markdown: {rec_md}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) - From 2be621e9199236204e239d555b67a4c91fb93bab Mon Sep 17 00:00:00 2001 From: JY Tan Date: Fri, 10 Apr 2026 13:40:42 -0700 Subject: [PATCH 5/8] Cleanup --- drift/core/drift_sdk.py | 6 ---- drift/core/tracing/td_span_processor.py | 36 ++++------------------- tests/unit/test_td_span_processor.py | 39 +------------------------ 3 files changed, 7 insertions(+), 74 deletions(-) diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 1d2aa61..14f7b90 100644 --- 
a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -299,8 +299,6 @@ def initialize( instance._td_span_processor = TdSpanProcessor( exporter=instance.span_exporter, mode=instance.mode, - sampling_rate=instance._sampling_rate, - app_ready=instance.app_ready, environment=env, ) instance._td_span_processor.start() @@ -848,10 +846,6 @@ def mark_app_as_ready(self) -> None: self.app_ready = True - # Update span processor with app_ready flag - if self._td_span_processor: - self._td_span_processor.update_app_ready(True) - if self.mode == TuskDriftMode.REPLAY: logger.debug("Replay mode active - ready to serve mocked responses") elif self.mode == TuskDriftMode.RECORD: diff --git a/drift/core/tracing/td_span_processor.py b/drift/core/tracing/td_span_processor.py index ef1e6e7..98adc6d 100644 --- a/drift/core/tracing/td_span_processor.py +++ b/drift/core/tracing/td_span_processor.py @@ -47,21 +47,21 @@ class TdSpanProcessor(SpanProcessor): This processor implements OpenTelemetry's SpanProcessor interface and serves as the bridge between OTel's tracing system and Drift's export infrastructure. + Root-request admission sampling happens earlier in inbound instrumentations, + so this processor only handles ended spans that were already allowed through. When a span ends: 1. Convert to CleanSpanData using otel_converter - 2. Apply sampling logic - 3. Apply trace blocking logic - 4. Validate protobuf serialization - 5. Forward to batch processor for export + 2. Apply trace blocking logic + 3. Validate protobuf serialization + 4. Handle REPLAY-mode inbound span forwarding + 5. Forward RECORD-mode spans to the batch processor for export """ def __init__( self, exporter: TdSpanExporter, mode: TuskDriftMode, - sampling_rate: float = 1.0, - app_ready: bool = False, environment: str | None = None, ) -> None: """Initialize the TdSpanProcessor. 
@@ -69,14 +69,10 @@ def __init__( Args: exporter: The TdSpanExporter to use for span export mode: SDK mode (RECORD, REPLAY, DISABLED) - sampling_rate: Sampling rate (0.0-1.0) - app_ready: Whether the application is ready environment: Environment name to include on spans """ self._exporter = exporter self._mode = mode - self._sampling_rate = sampling_rate - self._app_ready = app_ready self._environment = environment # We'll import and create batch processor lazily to avoid circular imports @@ -244,23 +240,3 @@ def force_flush(self, timeout_millis: int = 30000) -> bool: except Exception as e: logger.error(f"Error during force_flush: {e}") return False - - def update_app_ready(self, app_ready: bool) -> None: - """Update the app_ready flag. - - This is called when the application marks itself as ready. - - Args: - app_ready: Whether the application is ready - """ - self._app_ready = app_ready - logger.debug(f"TdSpanProcessor app_ready updated to {app_ready}") - - def update_sampling_rate(self, sampling_rate: float) -> None: - """Update the sampling rate. 
- - Args: - sampling_rate: New sampling rate (0.0-1.0) - """ - self._sampling_rate = sampling_rate - logger.debug(f"TdSpanProcessor sampling_rate updated to {sampling_rate}") diff --git a/tests/unit/test_td_span_processor.py b/tests/unit/test_td_span_processor.py index eae2f61..5a34c3c 100644 --- a/tests/unit/test_td_span_processor.py +++ b/tests/unit/test_td_span_processor.py @@ -23,8 +23,7 @@ def test_initializes_with_required_params(self, mocker): assert processor._exporter is mock_exporter assert processor._mode == TuskDriftMode.RECORD - assert processor._sampling_rate == 1.0 - assert processor._app_ready is False + assert processor._environment is None assert processor._started is False def test_initializes_with_optional_params(self, mocker): @@ -34,14 +33,10 @@ def test_initializes_with_optional_params(self, mocker): processor = TdSpanProcessor( exporter=mock_exporter, mode=TuskDriftMode.REPLAY, - sampling_rate=0.5, - app_ready=True, environment="production", ) assert processor._mode == TuskDriftMode.REPLAY - assert processor._sampling_rate == 0.5 - assert processor._app_ready is True assert processor._environment == "production" @@ -392,35 +387,3 @@ def test_force_flush_handles_exception(self, mocker): result = processor.force_flush() assert result is False - - -class TestTdSpanProcessorUpdateMethods: - """Tests for TdSpanProcessor update methods.""" - - def test_update_app_ready(self, mocker): - """Should update app_ready flag.""" - mock_exporter = mocker.MagicMock() - processor = TdSpanProcessor( - exporter=mock_exporter, - mode=TuskDriftMode.RECORD, - ) - - assert processor._app_ready is False - - processor.update_app_ready(True) - - assert processor._app_ready is True - - def test_update_sampling_rate(self, mocker): - """Should update sampling rate.""" - mock_exporter = mocker.MagicMock() - processor = TdSpanProcessor( - exporter=mock_exporter, - mode=TuskDriftMode.RECORD, - ) - - assert processor._sampling_rate == 1.0 - - 
processor.update_sampling_rate(0.5) - - assert processor._sampling_rate == 0.5 From 0d4844d73e49b6cf5f55a96df32190e969724adc Mon Sep 17 00:00:00 2001 From: JY Tan Date: Fri, 10 Apr 2026 13:46:48 -0700 Subject: [PATCH 6/8] Update docs --- docs/environment-variables.md | 9 ++-- docs/initialization.md | 86 +++++++++++++++++++++++++++++------ docs/quickstart.md | 13 +++++- 3 files changed, 88 insertions(+), 20 deletions(-) diff --git a/docs/environment-variables.md b/docs/environment-variables.md index 0425f1b..4fa72be 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -112,11 +112,12 @@ Your Tusk Drift API key, required when using Tusk Cloud for storing and managing ## TUSK_SAMPLING_RATE -Controls what percentage of requests are recorded during trace collection. +Controls the base recording rate used during trace collection. - **Type:** Number between 0.0 and 1.0 -- **Default:** 1.0 (100% of requests) -- **Precedence:** This environment variable is overridden by the `sampling_rate` parameter in `TuskDrift.initialize()`, but takes precedence over the `sampling_rate` setting in `.tusk/config.yaml` +- **If unset:** Falls back to `.tusk/config.yaml` and then the default base rate of `1.0` +- **Precedence:** This environment variable is overridden by the `sampling_rate` parameter in `TuskDrift.initialize()`, but takes precedence over `recording.sampling.base_rate` and the legacy `recording.sampling_rate` setting in `.tusk/config.yaml` +- **Scope:** This only overrides the base rate. It does not change `recording.sampling.mode` or `recording.sampling.min_rate` **Examples:** @@ -128,6 +129,8 @@ TUSK_SAMPLING_RATE=1.0 python app.py TUSK_SAMPLING_RATE=0.1 python app.py ``` +If `recording.sampling.mode: adaptive` is enabled in `.tusk/config.yaml`, this environment variable still only changes the base rate; adaptive load shedding remains active. 
+ For more details on sampling rate configuration methods and precedence, see the [Initialization Guide](./initialization.md#configure-sampling-rate). ## Rust Core Flags diff --git a/docs/initialization.md b/docs/initialization.md index f8c707e..0414b79 100644 --- a/docs/initialization.md +++ b/docs/initialization.md @@ -73,8 +73,8 @@ Create an initialization file or add the SDK initialization to your application sampling_rate float - 1.0 - Override sampling rate (0.0 - 1.0) for recording. Takes precedence over TUSK_SAMPLING_RATE env var and config file. + None + Override the base sampling rate (0.0 - 1.0) for recording. Takes precedence over TUSK_SAMPLING_RATE and config file base-rate settings. Does not change recording.sampling.mode. @@ -172,28 +172,42 @@ if __name__ == "__main__": ## Configure Sampling Rate -The sampling rate determines what percentage of requests are recorded during replay tests. Tusk Drift supports three ways to configure the sampling rate, with the following precedence (highest to lowest): +Sampling controls what percentage of inbound requests are recorded in `RECORD` mode. + +Tusk Drift supports two sampling modes in `.tusk/config.yaml`: + +- `fixed`: record requests at a constant base rate. +- `adaptive`: start from a base rate and automatically shed load when queue pressure, export failures, or memory pressure indicate the SDK should back off. In severe conditions the SDK can temporarily pause recording entirely. + +Sampling configuration is resolved in two layers: -1. **Init Parameter** -2. **Environment Variable** (`TUSK_SAMPLING_RATE`) -3. **Configuration File** (`.tusk/config.yaml`) +1. **Base rate precedence** (highest to lowest): + - `TuskDrift.initialize(sampling_rate=...)` + - `TUSK_SAMPLING_RATE` + - `.tusk/config.yaml` `recording.sampling.base_rate` + - `.tusk/config.yaml` legacy `recording.sampling_rate` + - default base rate `1.0` +2. 
**Mode and minimum rate**: + - `recording.sampling.mode` comes from `.tusk/config.yaml` and defaults to `fixed` + - `recording.sampling.min_rate` is only used in `adaptive` mode and defaults to `0.001` when omitted -If not specified, the default sampling rate is `1.0` (100%). +> [!NOTE] +> Requests before `sdk.mark_app_as_ready()` are always recorded. Sampling applies to normal inbound traffic after startup. -### Method 1: Init Parameter (Programmatic Override) +### Method 1: Init Parameter (Programmatic Base-Rate Override) -Set the sampling rate directly in your initialization code: +Set the base sampling rate directly in your initialization code: ```python sdk = TuskDrift.initialize( api_key=os.environ.get("TUSK_API_KEY"), - sampling_rate=0.1, # 10% of requests + sampling_rate=0.1, # Base rate: 10% of requests ) ``` ### Method 2: Environment Variable -Set the `TUSK_SAMPLING_RATE` environment variable: +Set the `TUSK_SAMPLING_RATE` environment variable to override the base sampling rate: ```bash # Development - record everything @@ -205,17 +219,41 @@ TUSK_SAMPLING_RATE=0.1 python app.py ### Method 3: Configuration File -Update the configuration file `.tusk/config.yaml` to include a `recording` section: +Use the nested `recording.sampling` config to choose `fixed` vs `adaptive` mode and set the base/minimum rates. + +**Fixed sampling example:** ```yaml # ... existing configuration ... recording: - sampling_rate: 0.1 + sampling: + mode: fixed + base_rate: 0.1 export_spans: true enable_env_var_recording: true ``` +**Adaptive sampling example:** + +```yaml +# ... existing configuration ... 
+ +recording: + sampling: + mode: adaptive + base_rate: 0.25 + min_rate: 0.01 + export_spans: true +``` + +**Legacy config still supported:** + +```yaml +recording: + sampling_rate: 0.1 +``` + ### Recording Configuration Options @@ -229,10 +267,28 @@ recording: - + + + + + + + - + + + + + + + + + + + + + diff --git a/docs/quickstart.md b/docs/quickstart.md index e25aff5..6549f09 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -2,9 +2,18 @@ Let's walk through recording and replaying your first trace: -## Step 1: Set sampling rate to 1.0 +## Step 1: Set fixed sampling to 1.0 -Set the `sampling_rate` in `.tusk/config.yaml` to 1.0 to ensure that all requests are recorded. +Update `.tusk/config.yaml` so your first recording captures every request: + +```yaml +recording: + sampling: + mode: fixed + base_rate: 1.0 +``` + +Legacy `recording.sampling_rate: 1.0` still works, but `recording.sampling` is the preferred config shape. ## Step 2: Start server in record mode From a99019e678e85ccf33fd170d923ad84ea43f7adb Mon Sep 17 00:00:00 2001 From: JY Tan Date: Fri, 10 Apr 2026 14:03:16 -0700 Subject: [PATCH 7/8] Standardize to use TUSK_RECORDING_SAMPLING_RATE --- docs/environment-variables.md | 8 +- docs/initialization.md | 15 ++-- drift/core/drift_sdk.py | 14 ++-- .../aiohttp/e2e-tests/docker-compose.yml | 1 + .../django/e2e-tests/docker-compose.yml | 1 + .../fastapi/e2e-tests/docker-compose.yml | 1 + .../flask/e2e-tests/docker-compose.yml | 1 + .../grpc/e2e-tests/docker-compose.yml | 1 + .../httpx/e2e-tests/docker-compose.yml | 1 + .../psycopg/e2e-tests/docker-compose.yml | 1 + .../psycopg2/e2e-tests/docker-compose.yml | 1 + .../redis/e2e-tests/docker-compose.yml | 1 + .../requests/e2e-tests/docker-compose.yml | 1 + .../urllib/e2e-tests/docker-compose.yml | 1 + .../urllib3/e2e-tests/docker-compose.yml | 1 + .../django-postgres/docker-compose.yml | 1 + .../django-redis/docker-compose.yml | 1 + .../fastapi-postgres/docker-compose.yml | 1 + 
tests/unit/test_drift_sdk.py | 73 +++++++++++++++---- tests/unit/test_sampling.py | 2 +- 20 files changed, 98 insertions(+), 29 deletions(-) diff --git a/docs/environment-variables.md b/docs/environment-variables.md index 4fa72be..f829229 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -110,7 +110,7 @@ Your Tusk Drift API key, required when using Tusk Cloud for storing and managing This will securely store your auth key for future replay sessions. -## TUSK_SAMPLING_RATE +## TUSK_RECORDING_SAMPLING_RATE Controls the base recording rate used during trace collection. @@ -123,14 +123,16 @@ Controls the base recording rate used during trace collection. ```bash # Record all requests (100%) -TUSK_SAMPLING_RATE=1.0 python app.py +TUSK_RECORDING_SAMPLING_RATE=1.0 python app.py # Record 10% of requests -TUSK_SAMPLING_RATE=0.1 python app.py +TUSK_RECORDING_SAMPLING_RATE=0.1 python app.py ``` If `recording.sampling.mode: adaptive` is enabled in `.tusk/config.yaml`, this environment variable still only changes the base rate; adaptive load shedding remains active. +`TUSK_RECORDING_SAMPLING_RATE` is the canonical variable, but `TUSK_SAMPLING_RATE` is still accepted as a backward-compatible alias. + For more details on sampling rate configuration methods and precedence, see the [Initialization Guide](./initialization.md#configure-sampling-rate). ## Rust Core Flags diff --git a/docs/initialization.md b/docs/initialization.md index 0414b79..c425f57 100644 --- a/docs/initialization.md +++ b/docs/initialization.md @@ -74,7 +74,7 @@ Create an initialization file or add the SDK initialization to your application - +
sampling_rate sampling.mode "fixed" | "adaptive" "fixed" Selects constant sampling or adaptive load shedding.
sampling.base_rate float 1.0 The sampling rate (0.0 - 1.0). 1.0 means 100% of requests are recorded, 0.0 means 0% of requests are recorded. The base sampling rate (0.0 - 1.0). This is the preferred config key and can be overridden by TUSK_SAMPLING_RATE or the sampling_rate init parameter.
sampling.min_rate float 0.001 in adaptive mode The minimum steady-state sampling rate for adaptive mode. In critical conditions the SDK can still temporarily pause recording.
sampling_rate float None Legacy fallback for the base sampling rate. Still supported for backward compatibility, but recording.sampling.base_rate is preferred.
export_spans sampling_rate float None Override the base sampling rate (0.0 - 1.0) for recording. Takes precedence over TUSK_SAMPLING_RATE and config file base-rate settings. Does not change recording.sampling.mode. Override the base sampling rate (0.0 - 1.0) for recording. Takes precedence over TUSK_RECORDING_SAMPLING_RATE and config file base-rate settings. Does not change recording.sampling.mode.
@@ -183,7 +183,8 @@ Sampling configuration is resolved in two layers: 1. **Base rate precedence** (highest to lowest): - `TuskDrift.initialize(sampling_rate=...)` - - `TUSK_SAMPLING_RATE` + - `TUSK_RECORDING_SAMPLING_RATE` + - legacy alias `TUSK_SAMPLING_RATE` - `.tusk/config.yaml` `recording.sampling.base_rate` - `.tusk/config.yaml` legacy `recording.sampling_rate` - default base rate `1.0` @@ -207,16 +208,18 @@ sdk = TuskDrift.initialize( ### Method 2: Environment Variable -Set the `TUSK_SAMPLING_RATE` environment variable to override the base sampling rate: +Set the `TUSK_RECORDING_SAMPLING_RATE` environment variable to override the base sampling rate: ```bash # Development - record everything -TUSK_SAMPLING_RATE=1.0 python app.py +TUSK_RECORDING_SAMPLING_RATE=1.0 python app.py # Production - sample 10% of requests -TUSK_SAMPLING_RATE=0.1 python app.py +TUSK_RECORDING_SAMPLING_RATE=0.1 python app.py ``` +`TUSK_SAMPLING_RATE` is still supported as a backward-compatible alias, but new setups should prefer `TUSK_RECORDING_SAMPLING_RATE`. + ### Method 3: Configuration File Use the nested `recording.sampling` config to choose `fixed` vs `adaptive` mode and set the base/minimum rates. @@ -276,7 +279,7 @@ recording: sampling.base_rate float 1.0 - The base sampling rate (0.0 - 1.0). This is the preferred config key and can be overridden by TUSK_SAMPLING_RATE or the sampling_rate init parameter. + The base sampling rate (0.0 - 1.0). This is the preferred config key and can be overridden by TUSK_RECORDING_SAMPLING_RATE or the sampling_rate init parameter. 
sampling.min_rate diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 14f7b90..2c58372 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -358,16 +358,20 @@ def _determine_sampling_config(self, init_param: float | None) -> ResolvedSampli base_rate = validated if base_rate is None: - env_rate = os.environ.get("TUSK_SAMPLING_RATE") - if env_rate is not None: + for env_key in ("TUSK_RECORDING_SAMPLING_RATE", "TUSK_SAMPLING_RATE"): + env_rate = os.environ.get(env_key) + if env_rate is None: + continue + try: parsed = float(env_rate) - validated = validate_sampling_rate(parsed, "TUSK_SAMPLING_RATE env var") + validated = validate_sampling_rate(parsed, f"{env_key} env var") if validated is not None: - logger.debug(f"Using sampling rate from env var: {validated}") + logger.debug(f"Using sampling rate from {env_key} env var: {validated}") base_rate = validated + break except ValueError: - logger.warning(f"Invalid TUSK_SAMPLING_RATE env var: {env_rate}") + logger.warning(f"Invalid {env_key} env var: {env_rate}") if base_rate is None and config_sampling and config_sampling.base_rate is not None: validated = validate_sampling_rate(config_sampling.base_rate, "config file recording.sampling.base_rate") diff --git a/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml b/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml index 14976cb..52a4ec4 100644 --- a/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml @@ -27,6 +27,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1} - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081} diff --git a/drift/instrumentation/django/e2e-tests/docker-compose.yml 
b/drift/instrumentation/django/e2e-tests/docker-compose.yml index 801a3c7..72b9b09 100644 --- a/drift/instrumentation/django/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/django/e2e-tests/docker-compose.yml @@ -28,6 +28,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1} - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081} diff --git a/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml b/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml index cf2e18c..9767b72 100644 --- a/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml @@ -27,6 +27,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1} - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081} diff --git a/drift/instrumentation/flask/e2e-tests/docker-compose.yml b/drift/instrumentation/flask/e2e-tests/docker-compose.yml index 8c73754..5d21955 100644 --- a/drift/instrumentation/flask/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/flask/e2e-tests/docker-compose.yml @@ -27,6 +27,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1} - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081} diff --git a/drift/instrumentation/grpc/e2e-tests/docker-compose.yml 
b/drift/instrumentation/grpc/e2e-tests/docker-compose.yml index c0d45a7..1671f4b 100644 --- a/drift/instrumentation/grpc/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/grpc/e2e-tests/docker-compose.yml @@ -18,6 +18,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} working_dir: /app volumes: diff --git a/drift/instrumentation/httpx/e2e-tests/docker-compose.yml b/drift/instrumentation/httpx/e2e-tests/docker-compose.yml index ae57669..6471395 100644 --- a/drift/instrumentation/httpx/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/httpx/e2e-tests/docker-compose.yml @@ -27,6 +27,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1} - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081} diff --git a/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml b/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml index 34e6e0a..95a6b6a 100644 --- a/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml @@ -38,6 +38,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} working_dir: /app volumes: diff --git a/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml b/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml index 608fa98..0e10cc8 100644 --- a/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml +++ 
b/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml @@ -38,6 +38,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} working_dir: /app volumes: diff --git a/drift/instrumentation/redis/e2e-tests/docker-compose.yml b/drift/instrumentation/redis/e2e-tests/docker-compose.yml index 84b269c..64772ea 100644 --- a/drift/instrumentation/redis/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/redis/e2e-tests/docker-compose.yml @@ -31,6 +31,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} working_dir: /app volumes: diff --git a/drift/instrumentation/requests/e2e-tests/docker-compose.yml b/drift/instrumentation/requests/e2e-tests/docker-compose.yml index 997da3a..0107d39 100644 --- a/drift/instrumentation/requests/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/requests/e2e-tests/docker-compose.yml @@ -27,6 +27,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1} - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081} diff --git a/drift/instrumentation/urllib/e2e-tests/docker-compose.yml b/drift/instrumentation/urllib/e2e-tests/docker-compose.yml index 4beb9b9..1436145 100644 --- a/drift/instrumentation/urllib/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/urllib/e2e-tests/docker-compose.yml @@ -27,6 +27,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - 
BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1} - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081} diff --git a/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml b/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml index 10b39a7..c61e044 100644 --- a/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml +++ b/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml @@ -27,6 +27,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1} - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081} diff --git a/drift/stack-tests/django-postgres/docker-compose.yml b/drift/stack-tests/django-postgres/docker-compose.yml index c20e94f..6877e01 100644 --- a/drift/stack-tests/django-postgres/docker-compose.yml +++ b/drift/stack-tests/django-postgres/docker-compose.yml @@ -39,6 +39,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} working_dir: /app volumes: diff --git a/drift/stack-tests/django-redis/docker-compose.yml b/drift/stack-tests/django-redis/docker-compose.yml index 1570b8a..e1c34db 100644 --- a/drift/stack-tests/django-redis/docker-compose.yml +++ b/drift/stack-tests/django-redis/docker-compose.yml @@ -32,6 +32,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - 
TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} working_dir: /app volumes: diff --git a/drift/stack-tests/fastapi-postgres/docker-compose.yml b/drift/stack-tests/fastapi-postgres/docker-compose.yml index 3497e5e..86ae38c 100644 --- a/drift/stack-tests/fastapi-postgres/docker-compose.yml +++ b/drift/stack-tests/fastapi-postgres/docker-compose.yml @@ -38,6 +38,7 @@ services: - BENCHMARKS=${BENCHMARKS:-} - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10} - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3} + - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-} - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-} working_dir: /app volumes: diff --git a/tests/unit/test_drift_sdk.py b/tests/unit/test_drift_sdk.py index e7b1d3a..de7ccd9 100644 --- a/tests/unit/test_drift_sdk.py +++ b/tests/unit/test_drift_sdk.py @@ -20,7 +20,13 @@ def reset_singleton(self): TuskDrift._instance = None TuskDrift._initialized = False # Clear environment variables - env_vars = ["TUSK_DRIFT_MODE", "TUSK_API_KEY", "TUSK_SAMPLING_RATE", "ENV"] + env_vars = [ + "TUSK_DRIFT_MODE", + "TUSK_API_KEY", + "TUSK_RECORDING_SAMPLING_RATE", + "TUSK_SAMPLING_RATE", + "ENV", + ] original_env = {k: os.environ.get(k) for k in env_vars} for var in env_vars: if var in os.environ: @@ -121,9 +127,10 @@ def reset_singleton(self): """Reset singleton state before each test.""" TuskDrift._instance = None TuskDrift._initialized = False - # Clear sampling rate env var - if "TUSK_SAMPLING_RATE" in os.environ: - del os.environ["TUSK_SAMPLING_RATE"] + # Clear sampling rate env vars + for env_var in ("TUSK_RECORDING_SAMPLING_RATE", "TUSK_SAMPLING_RATE"): + if env_var in os.environ: + del os.environ[env_var] yield TuskDrift._instance = None TuskDrift._initialized = False @@ -137,10 +144,31 @@ def test_uses_init_param_sampling_rate(self, reset_singleton): assert result == 0.5 - def test_uses_env_var_sampling_rate(self, reset_singleton): - """Should use sampling rate from env var if init param not provided.""" + def 
test_uses_recording_env_var_sampling_rate(self, reset_singleton): + """Should use the canonical recording env var if init param not provided.""" os.environ["TUSK_DRIFT_MODE"] = "DISABLED" - os.environ["TUSK_SAMPLING_RATE"] = "0.25" + os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25" + instance = TuskDrift.get_instance() + + result = instance._determine_sampling_rate(None) + + assert result == 0.25 + + def test_uses_legacy_sampling_env_var_as_alias(self, reset_singleton): + """Should fall back to the legacy env var for backward compatibility.""" + os.environ["TUSK_DRIFT_MODE"] = "DISABLED" + os.environ["TUSK_SAMPLING_RATE"] = "0.2" + instance = TuskDrift.get_instance() + + result = instance._determine_sampling_rate(None) + + assert result == 0.2 + + def test_recording_env_var_takes_precedence_over_legacy_alias(self, reset_singleton): + """Should prefer the canonical env var when both env vars are set.""" + os.environ["TUSK_DRIFT_MODE"] = "DISABLED" + os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25" + os.environ["TUSK_SAMPLING_RATE"] = "0.1" instance = TuskDrift.get_instance() result = instance._determine_sampling_rate(None) @@ -150,7 +178,7 @@ def test_uses_env_var_sampling_rate(self, reset_singleton): def test_init_param_takes_precedence_over_env_var(self, reset_singleton): """Should prefer init param over env var.""" os.environ["TUSK_DRIFT_MODE"] = "DISABLED" - os.environ["TUSK_SAMPLING_RATE"] = "0.25" + os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25" instance = TuskDrift.get_instance() result = instance._determine_sampling_rate(0.75) @@ -160,7 +188,7 @@ def test_init_param_takes_precedence_over_env_var(self, reset_singleton): def test_invalid_init_param_falls_back_to_env_var(self, reset_singleton): """Should use env var when init param is present but invalid.""" os.environ["TUSK_DRIFT_MODE"] = "DISABLED" - os.environ["TUSK_SAMPLING_RATE"] = "0.25" + os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25" instance = TuskDrift.get_instance() result = 
instance._determine_sampling_rate(2.0) @@ -190,10 +218,10 @@ def test_defaults_to_1_0(self, reset_singleton): assert result == 1.0 - def test_rejects_invalid_env_var_sampling_rate(self, reset_singleton): - """Should reject invalid env var and use default.""" + def test_rejects_invalid_recording_env_var_sampling_rate(self, reset_singleton): + """Should reject an invalid canonical env var and use default.""" os.environ["TUSK_DRIFT_MODE"] = "DISABLED" - os.environ["TUSK_SAMPLING_RATE"] = "invalid" + os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "invalid" instance = TuskDrift.get_instance() result = instance._determine_sampling_rate(None) @@ -203,7 +231,7 @@ def test_rejects_invalid_env_var_sampling_rate(self, reset_singleton): def test_invalid_env_var_falls_back_to_config_file(self, reset_singleton): """Should use config file when env var is present but invalid.""" os.environ["TUSK_DRIFT_MODE"] = "DISABLED" - os.environ["TUSK_SAMPLING_RATE"] = "invalid" + os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "invalid" instance = TuskDrift.get_instance() instance.file_config = TuskFileConfig( recording=RecordingConfig( @@ -215,6 +243,17 @@ def test_invalid_env_var_falls_back_to_config_file(self, reset_singleton): assert result == 0.4 + def test_invalid_recording_env_var_falls_back_to_legacy_alias(self, reset_singleton): + """Should use the legacy alias when the canonical env var is invalid.""" + os.environ["TUSK_DRIFT_MODE"] = "DISABLED" + os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "invalid" + os.environ["TUSK_SAMPLING_RATE"] = "0.4" + instance = TuskDrift.get_instance() + + result = instance._determine_sampling_rate(None) + + assert result == 0.4 + class TestTuskDriftInitialize: """Tests for TuskDrift.initialize method.""" @@ -225,7 +264,13 @@ def reset_singleton(self): TuskDrift._instance = None TuskDrift._initialized = False # Clear environment variables - env_vars = ["TUSK_DRIFT_MODE", "TUSK_API_KEY", "TUSK_SAMPLING_RATE", "ENV"] + env_vars = [ + "TUSK_DRIFT_MODE", + 
"TUSK_API_KEY", + "TUSK_RECORDING_SAMPLING_RATE", + "TUSK_SAMPLING_RATE", + "ENV", + ] for var in env_vars: if var in os.environ: del os.environ[var] diff --git a/tests/unit/test_sampling.py b/tests/unit/test_sampling.py index f6ee591..f166ec9 100644 --- a/tests/unit/test_sampling.py +++ b/tests/unit/test_sampling.py @@ -90,7 +90,7 @@ def test_rate_above_one_returns_none(self): def test_custom_source_in_warning(self): """Should include custom source in warning message.""" # Just verify it doesn't raise with custom source - result = validate_sampling_rate(-0.5, source="env var TUSK_SAMPLING_RATE") + result = validate_sampling_rate(-0.5, source="env var TUSK_RECORDING_SAMPLING_RATE") assert result is None def test_converts_to_float(self): From e495e8de86449526f3026dae4072325d409b9b4d Mon Sep 17 00:00:00 2001 From: JY Tan Date: Fri, 10 Apr 2026 14:17:51 -0700 Subject: [PATCH 8/8] Fix --- drift/core/drift_sdk.py | 10 ++++----- tests/unit/test_drift_sdk.py | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 2c58372..b93ac2c 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -184,11 +184,6 @@ def initialize( "log_level": log_level, } - sampling_config = instance._determine_sampling_config(sampling_rate) - instance._sampling_rate = sampling_config.base_rate - instance._sampling_mode = sampling_config.mode - instance._min_sampling_rate = sampling_config.min_rate - effective_api_key = api_key or os.environ.get("TUSK_API_KEY") if not env: @@ -202,6 +197,11 @@ def initialize( logger.debug("Already initialized, skipping...") return instance + sampling_config = instance._determine_sampling_config(sampling_rate) + instance._sampling_rate = sampling_config.base_rate + instance._sampling_mode = sampling_config.mode + instance._min_sampling_rate = sampling_config.min_rate + # Start coverage collection after the _initialized guard so repeated # initialize() 
calls don't stop/restart coverage and lose accumulated data. from .coverage_server import start_coverage_collection diff --git a/tests/unit/test_drift_sdk.py b/tests/unit/test_drift_sdk.py index de7ccd9..adb16d0 100644 --- a/tests/unit/test_drift_sdk.py +++ b/tests/unit/test_drift_sdk.py @@ -6,6 +6,7 @@ import pytest +from drift.core.adaptive_sampling import AdaptiveSamplingController, ResolvedSamplingConfig from drift.core.config import RecordingConfig, SamplingConfig, TuskFileConfig from drift.core.drift_sdk import TuskDrift from drift.core.types import TuskDriftMode @@ -337,6 +338,47 @@ def test_idempotent_initialization(self, reset_singleton, mocker): # TracerProvider should only be created once assert mock_provider.call_count == 1 + def test_second_initialize_does_not_mutate_live_adaptive_sampling_state(self, reset_singleton, mocker): + """Should keep sampling fields aligned with the live controller on repeated initialize calls.""" + mocker.patch("drift.core.drift_sdk.install_hooks") + mocker.patch("drift.core.drift_sdk.atexit") + mocker.patch("drift.core.drift_sdk.TracerProvider") + mocker.patch("drift.core.drift_sdk.trace") + mocker.patch.object(TuskDrift, "_start_adaptive_sampling_control_loop") + os.environ["TUSK_DRIFT_MODE"] = "RECORD" + + instance = TuskDrift.get_instance() + instance.file_config = TuskFileConfig( + recording=RecordingConfig( + sampling=SamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1), + ) + ) + + initialized_instance = TuskDrift.initialize(env="test") + initialized_instance._adaptive_sampling_controller = AdaptiveSamplingController( + ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1), + random_fn=lambda: 0.0, + now_fn=lambda: 0.0, + ) + + initialized_instance.file_config = TuskFileConfig( + recording=RecordingConfig( + sampling=SamplingConfig(mode="fixed", base_rate=0.2, min_rate=None), + ) + ) + + second_instance = TuskDrift.initialize(env="test", sampling_rate=0.9) + + assert second_instance is 
initialized_instance + assert second_instance._sampling_rate == 0.5 + assert second_instance._sampling_mode == "adaptive" + assert second_instance._min_sampling_rate == 0.1 + + decision = second_instance.should_record_root_request(is_pre_app_start=False) + assert decision.mode == "adaptive" + assert decision.base_rate == 0.5 + assert decision.min_rate == 0.1 + class TestTuskDriftMarkAppAsReady: """Tests for TuskDrift.mark_app_as_ready method."""