From 4be0d91b90edbbacaa22714328c9b0e88c2a3669 Mon Sep 17 00:00:00 2001 From: Rodos Date: Fri, 3 Jul 2026 23:06:50 +1000 Subject: [PATCH 1/4] fix(circuit_breaker): make probe election per-thread and synchronize local counters The half-open probe owner id was minted once per process, so every thread in a process passed the owner check once any sibling won the election and all of them probed the recovering backend. The id is now a per-thread uuid (pid-aware for forked workers); the existing DynamoDB conditional write then elects one prober across threads and processes with no protocol change. The in-memory failure/success counters and observed-state map are now guarded by a lock, with threshold crossings detected atomically so a trip is persisted exactly once. Persistence settings are keyed per circuit name instead of living as shared instance attributes, and the local cache no longer raises into the protected call on concurrent expiry. --- .../utilities/circuit_breaker_alpha/base.py | 82 +++++--- .../circuit_breaker_alpha/persistence/base.py | 95 ++++++---- .../persistence/record.py | 3 +- docs/utilities/circuit_breaker.md | 11 ++ .../circuit_breaker_alpha/conftest.py | 102 +++++----- .../test_circuit_breaker.py | 179 ++++++++++++++++++ .../test_dynamodb_persistence.py | 1 - 7 files changed, 370 insertions(+), 103 deletions(-) diff --git a/aws_lambda_powertools/utilities/circuit_breaker_alpha/base.py b/aws_lambda_powertools/utilities/circuit_breaker_alpha/base.py index b4f66b98dc8..88fcedf61fe 100644 --- a/aws_lambda_powertools/utilities/circuit_breaker_alpha/base.py +++ b/aws_lambda_powertools/utilities/circuit_breaker_alpha/base.py @@ -10,6 +10,8 @@ import datetime import logging +import os +import threading import uuid from typing import TYPE_CHECKING, Any @@ -37,8 +39,31 @@ # recovered), so stale local failure streaks can be invalidated. _LAST_OBSERVED_STATE: dict[str, CircuitState] = {} -# Stable per-environment identifier used to claim the half-open probe lock. -_ENVIRONMENT_ID = uuid.uuid4().hex +# Guards the three dicts above. Increments are read-modify-write and a threshold +# crossing must be observed by exactly one thread, so every access goes through this +# lock. Held only while mutating the dicts, never across persistence writes or user +# callbacks. +_COUNTERS_LOCK = threading.Lock() + +# Identifier used to claim the half-open probe lock, unique per thread so the store's +# conditional election picks a single prober across threads as well as processes. +_PROBE_OWNER = threading.local() + + +def _probe_owner_id() -> str: + """ + Return this thread's stable probe-owner identifier, minting it on first use. + + A uuid in thread-local storage rather than ``threading.get_ident()``: the OS reuses + thread ids, and a recycled id would let an unrelated thread pass the owner check and + probe alongside the real owner. The pid check re-mints the id in forked children, + which inherit the forking thread's local storage. + """ + pid = os.getpid() + if getattr(_PROBE_OWNER, "pid", None) != pid: + _PROBE_OWNER.id = f"{pid}#{uuid.uuid4().hex}" + _PROBE_OWNER.pid = pid + return _PROBE_OWNER.id class CircuitBreakerHandler: @@ -111,32 +136,35 @@ def handle(self) -> Any: # If we previously observed a non-CLOSED state and the circuit is now back to # CLOSED, another environment completed the recovery cycle. Reset local counters # so a stale partial failure streak doesn't immediately re-trip the circuit. - prev = _LAST_OBSERVED_STATE.get(self.name) - if prev is not None and prev != CircuitState.CLOSED: - _LOCAL_FAILURES[self.name] = 0 - _LAST_OBSERVED_STATE[self.name] = CircuitState.CLOSED + with _COUNTERS_LOCK: + prev = _LAST_OBSERVED_STATE.get(self.name) + if prev is not None and prev != CircuitState.CLOSED: + _LOCAL_FAILURES[self.name] = 0 + _LAST_OBSERVED_STATE[self.name] = CircuitState.CLOSED return self._call_closed() if record.state == CircuitState.OPEN: - _LAST_OBSERVED_STATE[self.name] = CircuitState.OPEN + with _COUNTERS_LOCK: + _LAST_OBSERVED_STATE[self.name] = CircuitState.OPEN # ``opened_at`` may legitimately be 0 (epoch); treat only None as missing. opened_at = record.opened_at if record.opened_at is not None else self._now() if self._now() >= opened_at + self.config.recovery_timeout: # Recovery window elapsed: try to become the single prober. - if self.persistence_store.try_acquire_half_open(self.name, _ENVIRONMENT_ID, opened_at): + if self.persistence_store.try_acquire_half_open(self.name, _probe_owner_id(), opened_at): self._notify(CircuitState.OPEN, CircuitState.HALF_OPEN, opened_at=opened_at) return self._call_probe() return self._open_response(record.to_circuit_info()) - # HALF_OPEN: only the environment that owns the probe lock runs. - _LAST_OBSERVED_STATE[self.name] = CircuitState.HALF_OPEN - if record.half_open_owner == _ENVIRONMENT_ID: + # HALF_OPEN: only the thread that owns the probe lock runs. + with _COUNTERS_LOCK: + _LAST_OBSERVED_STATE[self.name] = CircuitState.HALF_OPEN + if record.half_open_owner == _probe_owner_id(): return self._call_probe() # If the probe lease has expired (owner recycled mid-probe), take over. if record.probe_lease_expiry is not None and self._now() >= record.probe_lease_expiry: logger.debug("Circuit '%s' probe lease expired; attempting takeover.", self.name) - if self.persistence_store.try_acquire_half_open(self.name, _ENVIRONMENT_ID, record.opened_at or 0): + if self.persistence_store.try_acquire_half_open(self.name, _probe_owner_id(), record.opened_at or 0): return self._call_probe() return self._open_response(record.to_circuit_info()) @@ -148,9 +176,14 @@ def _call_closed(self) -> Any: except Exception as exc: if not self.config.counts_as_failure(exc): raise - failures = _LOCAL_FAILURES.get(self.name, 0) + 1 - _LOCAL_FAILURES[self.name] = failures - if failures >= self.config.failure_threshold: + # Increment and reset atomically so exactly one thread observes the threshold + # crossing; racing threads would otherwise lose increments (tripping late) or + # each persist the same transition. + with _COUNTERS_LOCK: + failures = _LOCAL_FAILURES.get(self.name, 0) + 1 + tripped = failures >= self.config.failure_threshold + _LOCAL_FAILURES[self.name] = 0 if tripped else failures + if tripped: logger.debug("Circuit '%s' tripping CLOSED to OPEN after %d failures.", self.name, failures) opened_at = self._now() self._safe_persist( @@ -159,11 +192,11 @@ def _call_closed(self) -> Any: failure_count=failures, opened_at=opened_at, ) - _LOCAL_FAILURES[self.name] = 0 self._notify(CircuitState.CLOSED, CircuitState.OPEN, opened_at=opened_at) raise else: - _LOCAL_FAILURES[self.name] = 0 + with _COUNTERS_LOCK: + _LOCAL_FAILURES[self.name] = 0 return result def _call_probe(self) -> Any: @@ -176,17 +209,20 @@ def _call_probe(self) -> Any: logger.debug("Circuit '%s' probe failed; reopening.", self.name) opened_at = self._now() self._safe_persist(self.persistence_store.save_reopen, self.name, opened_at=opened_at) - _LOCAL_SUCCESSES[self.name] = 0 + with _COUNTERS_LOCK: + _LOCAL_SUCCESSES[self.name] = 0 self._notify(CircuitState.HALF_OPEN, CircuitState.OPEN, opened_at=opened_at) raise else: - successes = _LOCAL_SUCCESSES.get(self.name, 0) + 1 - _LOCAL_SUCCESSES[self.name] = successes - if successes >= self.config.success_threshold: + with _COUNTERS_LOCK: + successes = _LOCAL_SUCCESSES.get(self.name, 0) + 1 + closed = successes >= self.config.success_threshold + _LOCAL_SUCCESSES[self.name] = 0 if closed else successes + if closed: + _LOCAL_FAILURES[self.name] = 0 + if closed: logger.debug("Circuit '%s' closing after %d probe successes.", self.name, successes) self._safe_persist(self.persistence_store.save_closed, self.name) - _LOCAL_SUCCESSES[self.name] = 0 - _LOCAL_FAILURES[self.name] = 0 self._notify(CircuitState.HALF_OPEN, CircuitState.CLOSED) return result diff --git a/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py b/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py index f1dee4c745f..a0fb8b98e16 100644 --- a/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py +++ b/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py @@ -10,8 +10,9 @@ import datetime import logging +import threading from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple from aws_lambda_powertools.shared.cache_dict import LRUDict from aws_lambda_powertools.utilities.circuit_breaker_alpha.persistence.record import CircuitStateRecord @@ -32,6 +33,17 @@ PERSISTED_STATE_TTL_BUFFER = 3600 +class _CircuitSettings(NamedTuple): + """Per-circuit tunables the layer captures from :meth:`configure`.""" + + local_cache_max_age: int + recovery_timeout: int + + +# Fallback for direct layer use before configure() has run; mirrors CircuitBreakerConfig defaults. +_DEFAULT_SETTINGS = _CircuitSettings(local_cache_max_age=5, recovery_timeout=30) + + class CircuitBreakerExistingLockError(Exception): """Internal signal that a conditional half-open probe write lost the race.""" @@ -50,19 +62,26 @@ class CircuitBreakerPersistenceLayer(ABC): def __init__(self) -> None: """Initialize defaults; real configuration happens in :meth:`configure`.""" - self.circuit_name: str = "" - self.local_cache_max_age: int = 5 - self.recovery_timeout: int = 30 + # Per-circuit tunables, keyed by circuit name. One persistence instance is shared + # by every circuit (and thread) using the same store, so these must never live in + # plain instance attributes: circuits with different configs would stamp each + # other's TTL and probe lease. A plain dict, not an LRUDict: evicting a live + # circuit's settings would silently swap in the defaults (wrong lease and TTL), + # whereas evicting a cache entry below only costs a store re-read. + self._settings: dict[str, _CircuitSettings] = {} # Maps circuit name -> the unix timestamp the locally cached record goes stale. # Kept separate from the record's durable ``expiry_timestamp`` (the store TTL) so # the short in-memory freshness window is never mistaken for the long store TTL. self._cache: LRUDict = LRUDict(max_items=LOCAL_CACHE_MAX_ITEMS) + # One lock for both maps: LRUDict reorders entries even on reads, so unguarded + # concurrent access can corrupt it or raise. + self._lock = threading.Lock() def configure(self, config: CircuitBreakerConfig, circuit_name: str) -> None: """ - Bind the layer to a circuit and its configuration. + Bind a circuit's configuration to the layer. - Called once per invocation by the handler; the assignments are cheap and the + Called once per invocation by the handler; the assignment is cheap and the same persistence instance is reused across invocations within an environment. Parameters @@ -70,43 +89,54 @@ def configure(self, config: CircuitBreakerConfig, circuit_name: str) -> None: config : CircuitBreakerConfig Configuration providing the local cache TTL and recovery timeout. circuit_name : str - The circuit this layer instance serves. + The circuit these settings apply to. """ - self.circuit_name = circuit_name - self.local_cache_max_age = config.local_cache_max_age - self.recovery_timeout = config.recovery_timeout + with self._lock: + self._settings[circuit_name] = _CircuitSettings( + local_cache_max_age=config.local_cache_max_age, + recovery_timeout=config.recovery_timeout, + ) + + def _settings_for(self, name: str) -> _CircuitSettings: + """Return a circuit's configured settings, or the defaults if never configured.""" + with self._lock: + return self._settings.get(name) or _DEFAULT_SETTINGS # ------------------------------------------------------------------ cache def _cache_key(self, name: str) -> str: return name - def _durable_ttl(self) -> int: + def _durable_ttl(self, name: str) -> int: """ Compute the store TTL stamped on a persisted record. Sized to outlive a full recovery window so a live circuit is never reaped mid-cycle, while an abandoned circuit (no further writes) self-cleans soon after. """ - return int(datetime.datetime.now().timestamp()) + self.recovery_timeout + PERSISTED_STATE_TTL_BUFFER + now = int(datetime.datetime.now().timestamp()) + return now + self._settings_for(name).recovery_timeout + PERSISTED_STATE_TTL_BUFFER def _save_to_cache(self, record: CircuitStateRecord) -> None: """Cache a record locally with a short in-memory freshness window.""" - local_expiry = int(datetime.datetime.now().timestamp()) + self.local_cache_max_age - self._cache[self._cache_key(record.name)] = (local_expiry, record) + local_expiry = int(datetime.datetime.now().timestamp()) + self._settings_for(record.name).local_cache_max_age + with self._lock: + self._cache[self._cache_key(record.name)] = (local_expiry, record) def _retrieve_from_cache(self, name: str) -> CircuitStateRecord | None: """Return a cached record if present and still within its local freshness window.""" - cached = self._cache.get(self._cache_key(name)) - if cached is None: - return None + with self._lock: + cached = self._cache.get(self._cache_key(name)) + if cached is None: + return None - local_expiry, record = cached - if int(datetime.datetime.now().timestamp()) >= local_expiry: - del self._cache[self._cache_key(name)] - return None + local_expiry, record = cached + if int(datetime.datetime.now().timestamp()) >= local_expiry: + # pop, not del: this must never raise into the protected call. + self._cache.pop(self._cache_key(name), None) + return None - return record + return record # ------------------------------------------------------------- public API @@ -175,19 +205,19 @@ def save_open(self, name: str, failure_count: int, opened_at: int) -> None: state=CircuitState.OPEN, failure_count=failure_count, opened_at=opened_at, - expiry_timestamp=self._durable_ttl(), + expiry_timestamp=self._durable_ttl(name), ) self._put_record(record) self._save_to_cache(record) def try_acquire_half_open(self, name: str, owner: str, opened_at: int) -> bool: """ - Atomically elect a single environment to run the half-open probe. + Atomically elect a single worker to run the half-open probe. The conditional write succeeds only when the circuit is OPEN with no existing lock owner AND the ``opened_at`` matches what the caller observed (guards against stale eventually-consistent reads). A lease expiry is stamped so that if the - winning environment is recycled before completing the probe, others can take over + winning worker is recycled before completing the probe, others can take over once the lease lapses. Parameters @@ -195,25 +225,26 @@ def try_acquire_half_open(self, name: str, owner: str, opened_at: int) -> bool: name : str Circuit name. owner : str - Identifier of the environment attempting the probe. + Identifier of the worker (one thread in one execution environment) + attempting the probe. opened_at : int The ``opened_at`` the caller observed, kept stable across the transition. Returns ------- bool - ``True`` if this environment won the probe lock, ``False`` if another - environment already holds it. + ``True`` if this worker won the probe lock, ``False`` if another + worker already holds it. """ # Lease = recovery_timeout gives the probe a full cycle to complete. - probe_lease_expiry = int(datetime.datetime.now().timestamp()) + self.recovery_timeout + probe_lease_expiry = int(datetime.datetime.now().timestamp()) + self._settings_for(name).recovery_timeout record = CircuitStateRecord( name=name, state=CircuitState.HALF_OPEN, opened_at=opened_at, half_open_owner=owner, probe_lease_expiry=probe_lease_expiry, - expiry_timestamp=self._durable_ttl(), + expiry_timestamp=self._durable_ttl(name), ) try: self._put_record(record, condition="half_open", expected_opened_at=opened_at) @@ -228,7 +259,7 @@ def save_closed(self, name: str) -> None: name=name, state=CircuitState.CLOSED, failure_count=0, - expiry_timestamp=self._durable_ttl(), + expiry_timestamp=self._durable_ttl(name), ) self._update_record(record) self._save_to_cache(record) @@ -245,7 +276,7 @@ def save_reopen(self, name: str, opened_at: int) -> None: name=name, state=CircuitState.OPEN, opened_at=opened_at, - expiry_timestamp=self._durable_ttl(), + expiry_timestamp=self._durable_ttl(name), ) self._update_record(record) self._save_to_cache(record) diff --git a/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/record.py b/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/record.py index 71086c3c1a7..ec48737810f 100644 --- a/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/record.py +++ b/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/record.py @@ -30,7 +30,8 @@ class CircuitStateRecord: Unix timestamp (seconds) the circuit opened. Anchors the recovery timeout; ``None`` while closed. half_open_owner : str | None - Identifier of the execution environment that won the half-open probe lock, if any. + Identifier of the worker (one thread in one execution environment) that won the + half-open probe lock, if any. expiry_timestamp : int | None Unix timestamp (seconds) for the store's TTL attribute. """ diff --git a/docs/utilities/circuit_breaker.md b/docs/utilities/circuit_breaker.md index f83685862f0..aa524d0234a 100644 --- a/docs/utilities/circuit_breaker.md +++ b/docs/utilities/circuit_breaker.md @@ -20,6 +20,7 @@ The circuit breaker utility stops sending traffic to an unhealthy downstream dep * Hands rejected requests to an `on_circuit_open` callback so you decide what happens next (buffer, drop, return a cached value) * Tests recovery with an explicit half-open probe rather than blindly retrying everything at once * Shares circuit state across execution environments via Amazon DynamoDB +* Safe under concurrency: one half-open probe across all threads and execution environments, and synchronized failure counting * Keeps the healthy path write-free: failures are counted in memory and only persisted on a state transition ## Terminology @@ -182,6 +183,16 @@ Passing both raises `CircuitBreakerConfigError`. An exception that doesn't count After `recovery_timeout` seconds, the circuit moves to `HALF_OPEN` and elects a **single** execution environment (via a conditional DynamoDB write) to run a probe. If `success_threshold` consecutive probes succeed, the circuit closes; a single failing probe reopens it. This stops a thundering herd of every environment hammering a recovering backend at once. +!!! note "Thread safety" + The utility is safe to share across threads: within a multi-threaded environment the probe election picks a single + thread, so the single-prober guarantee spans threads as well as environments, and the in-memory failure counter is + synchronized. Single-threaded functions (the normal Lambda model) are unaffected. `on_circuit_open` and + `on_transition` hooks may run concurrently from multiple threads. + + Probe ownership belongs to the thread that won the election. If that thread never runs the circuit again (for + example, a thread-per-request worker pool), recovery waits for the probe lease to expire before another thread or + environment takes over. + ### State coordination across environments The consecutive-failure counter lives in memory per execution environment, so a healthy circuit performs **no writes**. Only when an environment reaches `failure_threshold` does it persist `OPEN`. The shared state is cached locally for `local_cache_max_age` seconds to avoid a read per invocation. A cache miss (cold start or expired entry) forces a read-through before routing. diff --git a/tests/functional/circuit_breaker_alpha/conftest.py b/tests/functional/circuit_breaker_alpha/conftest.py index e2c9fc4d0d8..297c088750b 100644 --- a/tests/functional/circuit_breaker_alpha/conftest.py +++ b/tests/functional/circuit_breaker_alpha/conftest.py @@ -1,5 +1,7 @@ from __future__ import annotations +import threading + import pytest import aws_lambda_powertools.utilities.circuit_breaker_alpha.base as base_module @@ -15,23 +17,29 @@ class FakePersistence(CircuitBreakerPersistenceLayer): """In-memory store for exercising the handler state machine without DynamoDB.""" + # Each DynamoDB operation is atomic, including the conditional put; the fake must be + # too, or threaded tests race inside the fake instead of the code under test. + # Class-level so store instances sharing a db (multi-environment tests) share it. + _db_lock = threading.Lock() + def __init__(self): self.db: dict[str, CircuitStateRecord] = {} super().__init__() def _get_record(self, name: str) -> CircuitStateRecord | None: - if name not in self.db: - return None - stored = self.db[name] - # Return a copy so the handler can't mutate stored state by reference. - return CircuitStateRecord( - name=stored.name, - state=stored.state, - failure_count=stored.failure_count, - opened_at=stored.opened_at, - half_open_owner=stored.half_open_owner, - probe_lease_expiry=stored.probe_lease_expiry, - ) + with self._db_lock: + if name not in self.db: + return None + stored = self.db[name] + # Return a copy so the handler can't mutate stored state by reference. + return CircuitStateRecord( + name=stored.name, + state=stored.state, + failure_count=stored.failure_count, + opened_at=stored.opened_at, + half_open_owner=stored.half_open_owner, + probe_lease_expiry=stored.probe_lease_expiry, + ) def _put_record( self, @@ -39,46 +47,48 @@ def _put_record( condition: str | None = None, expected_opened_at: int | None = None, ) -> None: - if condition == "half_open": - existing = self.db.get(record.name) - now = CircuitBreakerHandler._now() - - # Mirror the DynamoDB condition: two valid paths - # Path 1: state=OPEN AND no owner (AND opened_at matches if provided) - # Path 2: state=HALF_OPEN AND probe_lease_expiry <= now (lease takeover) - fresh_election_ok = existing is None or ( - existing.state == CircuitState.OPEN - and existing.half_open_owner is None - and (expected_opened_at is None or existing.opened_at == expected_opened_at) - ) - lease_takeover_ok = ( - existing is not None - and existing.state == CircuitState.HALF_OPEN - and existing.probe_lease_expiry is not None - and now >= existing.probe_lease_expiry - ) - - if not fresh_election_ok and not lease_takeover_ok: - raise CircuitBreakerExistingLockError - self.db[record.name] = record + with self._db_lock: + if condition == "half_open": + existing = self.db.get(record.name) + now = CircuitBreakerHandler._now() + + # Mirror the DynamoDB condition: two valid paths + # Path 1: state=OPEN AND no owner (AND opened_at matches if provided) + # Path 2: state=HALF_OPEN AND probe_lease_expiry <= now (lease takeover) + fresh_election_ok = existing is None or ( + existing.state == CircuitState.OPEN + and existing.half_open_owner is None + and (expected_opened_at is None or existing.opened_at == expected_opened_at) + ) + lease_takeover_ok = ( + existing is not None + and existing.state == CircuitState.HALF_OPEN + and existing.probe_lease_expiry is not None + and now >= existing.probe_lease_expiry + ) + + if not fresh_election_ok and not lease_takeover_ok: + raise CircuitBreakerExistingLockError + self.db[record.name] = record def _update_record(self, record: CircuitStateRecord) -> None: # Mirror DynamoDB UpdateItem semantics: a partial merge driven by which # attributes the backend actually writes, NOT a wholesale replace. This is # what exposes attributes the update path forgets to clear (e.g. a stale # half_open_owner left behind on reopen). - existing = self.db.get(record.name) - if existing is None: - self.db[record.name] = record - return - existing.state = record.state - existing.failure_count = record.failure_count - existing.expiry_timestamp = record.expiry_timestamp - # Leaving HALF_OPEN (close or reopen) always releases the probe-owner lock and - # probe lease; only opened_at differs between the two transitions. - existing.half_open_owner = None - existing.probe_lease_expiry = None - existing.opened_at = record.opened_at + with self._db_lock: + existing = self.db.get(record.name) + if existing is None: + self.db[record.name] = record + return + existing.state = record.state + existing.failure_count = record.failure_count + existing.expiry_timestamp = record.expiry_timestamp + # Leaving HALF_OPEN (close or reopen) always releases the probe-owner lock and + # probe lease; only opened_at differs between the two transitions. + existing.half_open_owner = None + existing.probe_lease_expiry = None + existing.opened_at = record.opened_at @pytest.fixture diff --git a/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py b/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py index 1227291ee66..911298f144b 100644 --- a/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py +++ b/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py @@ -1,5 +1,6 @@ from __future__ import annotations +import threading import warnings import pytest @@ -602,3 +603,181 @@ def test_error_with_details_formatting(): """Covers exceptions.py line 28 — __str__ with details.""" err = CircuitBreakerError("main message", "extra detail") assert str(err) == "main message - (extra detail)" + + +# --------------------------------------------------------------------------- thread safety (#8320) + + +def _run_in_threads(worker, count): + threads = [threading.Thread(target=worker) for _ in range(count)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + +def test_threads_elect_a_single_prober_on_recovery(store, now): + """Regression for #8320: after one thread wins the election, siblings must not also probe.""" + store.db["c"] = CircuitStateRecord(name="c", state=CircuitState.OPEN, failure_count=5, opened_at=now - 60) + # success_threshold=2 keeps the circuit HALF_OPEN after the winning probe, so a + # late-reading loser can never legitimately run the function through a closed circuit. + config = CircuitBreakerConfig(recovery_timeout=30, success_threshold=2, local_cache_max_age=0) + probes = [] + open_responses = [] + barrier = threading.Barrier(8) + + @circuit_breaker(name="c", persistence_store=store, config=config) + def call(): + probes.append(threading.current_thread().name) + return "ok" + + def worker(): + barrier.wait(timeout=10) + try: + call() + except CircuitBreakerOpenError: + open_responses.append(threading.current_thread().name) + + _run_in_threads(worker, 8) + + assert len(probes) == 1, "exactly one thread must probe per recovery window" + assert len(open_responses) == 7, "every other thread must get the open response" + assert store.db["c"].state == CircuitState.HALF_OPEN + assert store.db["c"].half_open_owner is not None + + +def test_threads_trip_at_exactly_the_failure_threshold(store): + """Racing increments must neither lose updates (tripping late) nor persist the trip twice.""" + threads_count, threshold = 8, 5 + config = CircuitBreakerConfig(failure_threshold=threshold, local_cache_max_age=0) + barrier = threading.Barrier(threads_count) + siblings_done = threading.Event() + done_lock = threading.Lock() + finished = [] + save_open_counts = [] + original_save_open = store.save_open + + def blocking_save_open(name, failure_count, opened_at): + save_open_counts.append(failure_count) + if len(save_open_counts) > 1: + # A second persist is the bug itself; unblock immediately so it fails fast. + siblings_done.set() + # Park the tripping thread mid-persist until every sibling has recorded its + # failure. Code that resets the counter only after persisting lets the + # remaining threads cross the threshold too and persist the trip again. + siblings_done.wait(timeout=10) + original_save_open(name, failure_count=failure_count, opened_at=opened_at) + + store.save_open = blocking_save_open + + @circuit_breaker(name="c", persistence_store=store, config=config) + def call(): + # Hold every thread inside the protected call so all pass the CLOSED check + # before any failure is recorded, forcing the increments to race. + barrier.wait(timeout=10) + raise ConnectionError("downstream down") + + raised = [] + + def worker(): + try: + call() + except ConnectionError: + raised.append(threading.current_thread().name) + finally: + with done_lock: + finished.append(threading.current_thread().name) + if len(finished) == threads_count - 1: + # Only the tripping thread is still parked in the persist spy. + siblings_done.set() + + _run_in_threads(worker, threads_count) + + assert save_open_counts == [threshold], "the trip must be persisted exactly once, at the threshold" + assert store.db["c"].state == CircuitState.OPEN + assert len(raised) == threads_count + assert base_module._LOCAL_FAILURES["c"] == threads_count - threshold, "post-trip failures restart the streak" + + +def test_probe_ownership_is_per_thread_and_stable_across_invocations(store, now): + store.db["c"] = CircuitStateRecord(name="c", state=CircuitState.OPEN, failure_count=5, opened_at=now - 60) + config = CircuitBreakerConfig(recovery_timeout=30, success_threshold=2, local_cache_max_age=0) + calls = [] + + @circuit_breaker(name="c", persistence_store=store, config=config) + def call(): + calls.append(threading.current_thread().name) + return "ok" + + # This thread wins the election; one more probe success is needed to close. + assert call() == "ok" + assert store.db["c"].state == CircuitState.HALF_OPEN + + # A sibling thread must not inherit ownership (it did when the id was per-process). + sibling_outcome = [] + + def sibling(): + try: + sibling_outcome.append(call()) + except CircuitBreakerOpenError: + sibling_outcome.append("rejected") + + _run_in_threads(sibling, 1) + assert sibling_outcome == ["rejected"] + + # The owner's identity must survive across invocations so it can finish the recovery; + # a per-handler or per-invocation id would strand the circuit until the lease expired. + assert call() == "ok" + assert store.db["c"].state == CircuitState.CLOSED + assert len(calls) == 2 + + +def test_single_probe_spans_environments_and_threads(store, now): + """The same election covers threads in one process and separate environments.""" + other_env = type(store)() + other_env.db = store.db + store.db["c"] = CircuitStateRecord(name="c", state=CircuitState.OPEN, failure_count=5, opened_at=now - 60) + config = CircuitBreakerConfig(recovery_timeout=30, success_threshold=2, local_cache_max_age=0) + probes = [] + open_responses = [] + barrier = threading.Barrier(8) + + def protect(persistence): + @circuit_breaker(name="c", persistence_store=persistence, config=config) + def call(): + probes.append(threading.current_thread().name) + return "ok" + + return call + + calls = [protect(store), protect(other_env)] + + def worker(index): + env_call = calls[index % 2] + barrier.wait(timeout=10) + try: + env_call() + except CircuitBreakerOpenError: + open_responses.append(threading.current_thread().name) + + threads = [threading.Thread(target=worker, args=(index,)) for index in range(8)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert len(probes) == 1, "one probe across all threads of all environments" + assert len(open_responses) == 7 + + +def test_settings_are_kept_per_circuit_on_a_shared_layer(store, now): + """Configuring another circuit on the same layer must not stamp this one's probe lease.""" + store.configure(CircuitBreakerConfig(recovery_timeout=30, local_cache_max_age=0), circuit_name="a") + store.configure(CircuitBreakerConfig(recovery_timeout=9999, local_cache_max_age=0), circuit_name="b") + store.db["a"] = CircuitStateRecord(name="a", state=CircuitState.OPEN, failure_count=5, opened_at=now - 60) + + assert store.try_acquire_half_open("a", "owner-1", now - 60) + + lease = store.db["a"].probe_lease_expiry + assert lease is not None + assert lease <= now + 30 + 2, "the lease must derive from circuit 'a' recovery_timeout, not 'b'" diff --git a/tests/functional/circuit_breaker_alpha/test_dynamodb_persistence.py b/tests/functional/circuit_breaker_alpha/test_dynamodb_persistence.py index 4555c088e36..f0dd88e1446 100644 --- a/tests/functional/circuit_breaker_alpha/test_dynamodb_persistence.py +++ b/tests/functional/circuit_breaker_alpha/test_dynamodb_persistence.py @@ -137,7 +137,6 @@ def test_save_open_item_contains_expiration_attribute(persistence): # the documented self-cleaning of abandoned circuits never happens. Capture the # actual PutItem params rather than asserting an exact (time-dependent) value. captured = {} - persistence.local_cache_max_age = 5 original_put = persistence.client.put_item From 36e29e4dedce5c5880aa4afe6e9126ae6f82435a Mon Sep 17 00:00:00 2001 From: Rodos Date: Sat, 4 Jul 2026 22:46:43 +1000 Subject: [PATCH 2/4] test(circuit_breaker): cover losing the expired-lease probe takeover election The takeover's failure arm was untested: when another environment wins the conditional election, the caller must get the open-circuit response and the protected function must not run. --- .../test_circuit_breaker.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py b/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py index 911298f144b..78dece09f86 100644 --- a/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py +++ b/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py @@ -530,6 +530,40 @@ def call(): call() +def test_half_open_expired_lease_lost_takeover_returns_open_response(store, now): + """Losing the takeover election must NOT probe: another thread/environment won + the conditional write between our expiry check and our acquire, so the call + has to get the open-circuit response — electing every candidate is exactly + the multi-prober bug the per-thread owner id exists to prevent.""" + config = CircuitBreakerConfig(recovery_timeout=30, success_threshold=1, local_cache_max_age=0) + + # A stranded probe with an expired lease, owned by someone else. + store.db["c"] = CircuitStateRecord( + name="c", + state=CircuitState.HALF_OPEN, + opened_at=now - 200, + half_open_owner="dead-env", + probe_lease_expiry=now - 10, # expired — takeover will be attempted + ) + + # The race's loser: the conditional election fails. + def lost_election(name, owner_id, opened_at): + return False + + store.try_acquire_half_open = lost_election + + protected_ran = {"value": False} + + @circuit_breaker(name="c", persistence_store=store, config=config) + def call(): + protected_ran["value"] = True + return "must not probe" + + with pytest.raises(CircuitBreakerOpenError): + call() + assert protected_ran["value"] is False + + def test_open_lost_election_returns_open_response(store, now): """Branch: try_acquire_half_open returns False (another env won the race).""" config = CircuitBreakerConfig(recovery_timeout=30, success_threshold=1, local_cache_max_age=0) From be27f1b1848511495060831463a2d606e4098182 Mon Sep 17 00:00:00 2001 From: Rodos Date: Sun, 5 Jul 2026 07:08:59 +1000 Subject: [PATCH 3/4] fix(circuit_breaker): address review feedback - Make the counter race test catch a missing lock: park readers mid-increment so an unsynchronized read-modify-write deterministically loses updates (fails 10/10 unlocked, passes 10/10 locked) - Document that on_circuit_open and on_transition callbacks can run concurrently and must be thread-safe - Use dict.get's default instead of `or` for the per-circuit settings lookup --- .../circuit_breaker_alpha/persistence/base.py | 2 +- docs/utilities/circuit_breaker.md | 7 +++++-- .../test_circuit_breaker.py | 21 ++++++++++++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py b/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py index a0fb8b98e16..3cc36894bd3 100644 --- a/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py +++ b/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py @@ -100,7 +100,7 @@ def configure(self, config: CircuitBreakerConfig, circuit_name: str) -> None: def _settings_for(self, name: str) -> _CircuitSettings: """Return a circuit's configured settings, or the defaults if never configured.""" with self._lock: - return self._settings.get(name) or _DEFAULT_SETTINGS + return self._settings.get(name, _DEFAULT_SETTINGS) # ------------------------------------------------------------------ cache diff --git a/docs/utilities/circuit_breaker.md b/docs/utilities/circuit_breaker.md index aa524d0234a..b2117f7316b 100644 --- a/docs/utilities/circuit_breaker.md +++ b/docs/utilities/circuit_breaker.md @@ -186,13 +186,16 @@ After `recovery_timeout` seconds, the circuit moves to `HALF_OPEN` and elects a !!! note "Thread safety" The utility is safe to share across threads: within a multi-threaded environment the probe election picks a single thread, so the single-prober guarantee spans threads as well as environments, and the in-memory failure counter is - synchronized. Single-threaded functions (the normal Lambda model) are unaffected. `on_circuit_open` and - `on_transition` hooks may run concurrently from multiple threads. + synchronized. Single-threaded functions (the normal Lambda model) are unaffected. Probe ownership belongs to the thread that won the election. If that thread never runs the circuit again (for example, a thread-per-request worker pool), recovery waits for the probe lease to expire before another thread or environment takes over. +!!! warning "Make your hooks thread-safe" + If your function runs multiple threads, `on_circuit_open` and `on_transition` callbacks can run concurrently + for the same circuit. Make them thread-safe. + ### State coordination across environments The consecutive-failure counter lives in memory per execution environment, so a healthy circuit performs **no writes**. Only when an environment reaches `failure_threshold` does it persist `OPEN`. The shared state is cached locally for `local_cache_max_age` seconds to avoid a read per invocation. A cache miss (cold start or expired entry) forces a read-through before routing. diff --git a/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py b/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py index 78dece09f86..8a6daf7c17f 100644 --- a/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py +++ b/tests/functional/circuit_breaker_alpha/test_circuit_breaker.py @@ -1,6 +1,7 @@ from __future__ import annotations import threading +import time import warnings import pytest @@ -680,8 +681,26 @@ def worker(): assert store.db["c"].half_open_owner is not None -def test_threads_trip_at_exactly_the_failure_threshold(store): +class _RacyCounters(dict): + """Counter dict that yields between the read and the write of an increment. + + Without this, the GIL makes the unlocked read-modify-write effectively atomic + (the interpreter rarely switches threads inside those few bytecodes), so a + missing lock would still pass. Holding every reader mid-increment long enough + for its siblings to read the same stale value makes the lost update + deterministic: unlocked, all threads write back the same count and the + threshold is never reached. + """ + + def get(self, key, default=None): + value = super().get(key, default) + time.sleep(0.005) # park mid-increment so sibling threads read the same stale value + return value + + +def test_threads_trip_at_exactly_the_failure_threshold(store, monkeypatch): """Racing increments must neither lose updates (tripping late) nor persist the trip twice.""" + monkeypatch.setattr(base_module, "_LOCAL_FAILURES", _RacyCounters()) threads_count, threshold = 8, 5 config = CircuitBreakerConfig(failure_threshold=threshold, local_cache_max_age=0) barrier = threading.Barrier(threads_count) From f194294fae1eed0876d46939c413cad891d279d3 Mon Sep 17 00:00:00 2001 From: Rodos Date: Sun, 5 Jul 2026 20:56:51 +1000 Subject: [PATCH 4/4] fix(circuit_breaker): avoid LRUDict.pop when evicting expired cache entries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - On Python 3.10, OrderedDict.pop re-enters the subclass __getitem__ after detaching the linked-list node, so LRUDict.pop raises KeyError for a present key and corrupts the dict (fixed in CPython 3.11) — this failed 14 tests on the 3.10 CI job only - Evict with a guarded del instead, which never calls subclass hooks; the surrounding lock already makes the get-then-delete atomic, and the try/except keeps the never-raise-into-the-protected-call contract explicit --- .../utilities/circuit_breaker_alpha/persistence/base.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py b/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py index 3cc36894bd3..a4ca7e1985b 100644 --- a/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py +++ b/aws_lambda_powertools/utilities/circuit_breaker_alpha/persistence/base.py @@ -132,8 +132,13 @@ def _retrieve_from_cache(self, name: str) -> CircuitStateRecord | None: local_expiry, record = cached if int(datetime.datetime.now().timestamp()) >= local_expiry: - # pop, not del: this must never raise into the protected call. - self._cache.pop(self._cache_key(name), None) + # Guarded del, not pop: on Python 3.10 OrderedDict.pop re-enters the + # subclass __getitem__ after detaching the node, so LRUDict.pop raises + # KeyError for a *present* key and corrupts the dict (fixed in 3.11). + try: + del self._cache[self._cache_key(name)] + except KeyError: + pass return None return record