-
Notifications
You must be signed in to change notification settings - Fork 0
MPT-21532 Add account-scoped authentication provider #350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,15 @@ | ||
| from mpt_api_client.auth.base import Authentication, BearerTokenAuthentication | ||
| from mpt_api_client.auth.account_scoped import AccountScopedAuthentication | ||
| from mpt_api_client.auth.base import ( | ||
| Authentication, | ||
| BearerTokenAuthentication, | ||
| InstallationTokenAuthentication, | ||
| ) | ||
| from mpt_api_client.auth.extension_framework import ExtensionFrameworkAuthentication | ||
|
|
||
| __all__ = [ # noqa: WPS410 | ||
| "AccountScopedAuthentication", | ||
| "Authentication", | ||
| "BearerTokenAuthentication", | ||
| "ExtensionFrameworkAuthentication", | ||
| "InstallationTokenAuthentication", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,171 @@ | ||
| """Account-scoped authentication for the MPT integration API. | ||
|
|
||
| This provider fetches account-scoped installation tokens and shares them across instances | ||
| through a process-wide cache keyed by ``(secret, account_id)``. Token fetches are serialized | ||
| per account, so concurrent callers for the same account trigger at most one token request. | ||
| """ | ||
|
|
||
| import asyncio | ||
| import datetime as dt | ||
| import threading | ||
| from collections.abc import AsyncGenerator, Generator | ||
| from dataclasses import dataclass | ||
| from typing import ClassVar, override | ||
|
|
||
| import httpx | ||
|
|
||
| from mpt_api_client.auth.base import InstallationTokenAuthentication | ||
| from mpt_api_client.exceptions import MPTError | ||
|
|
||
| DEFAULT_TOKEN_VALIDITY_LEEWAY_SECONDS = 60 | ||
|
|
||
| CacheKey = tuple[str, str] | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class _CachedToken: | ||
| """A cached account token together with its decoded expiry.""" | ||
|
|
||
| token: str | ||
| expires_at: dt.datetime | None | ||
|
|
||
|
|
||
| class AccountScopedAuthentication(InstallationTokenAuthentication): # noqa: WPS214 | ||
| """Authenticate with an account-scoped token from a shared, concurrency-safe cache. | ||
|
|
||
| Tokens are cached process-wide keyed by ``(secret, account_id)``, so several provider or | ||
| client instances for the same account reuse a single token. Refresh is serialized per | ||
| account through a lock with double-checked caching: concurrent callers for the same | ||
| account trigger at most one token request. Refresh happens proactively once the token is | ||
| within ``min_remaining_validity_seconds`` of its JWT ``exp`` claim, with a reactive | ||
| refresh on ``401 Unauthorized`` for tokens revoked before they expire. When the fetched | ||
| token carries no readable ``exp`` claim, proactive refresh is skipped and only the | ||
| reactive ``401`` path applies. | ||
|
|
||
| The token call is delegated to :class:`InstallationsTokenService` (and its async | ||
| counterpart) over a dedicated client authenticated with the extension secret; that | ||
| client's base URL is supplied by the owning HTTP client through :meth:`configure`. | ||
| """ | ||
|
|
||
| _token_cache: ClassVar[dict[CacheKey, _CachedToken]] = {} | ||
| _sync_locks: ClassVar[dict[CacheKey, threading.Lock]] = {} | ||
| _async_locks: ClassVar[dict[CacheKey, asyncio.Lock]] = {} | ||
|
|
||
| def __init__( | ||
| self, | ||
| secret: str, | ||
| account_id: str, | ||
| min_remaining_validity_seconds: int = DEFAULT_TOKEN_VALIDITY_LEEWAY_SECONDS, | ||
| ) -> None: | ||
| """Initialize the provider. | ||
|
|
||
| Args: | ||
| secret: Extension secret used to authenticate token requests. | ||
| account_id: Account the requested token is scoped to. | ||
| min_remaining_validity_seconds: Proactive refresh leeway before the JWT ``exp``. | ||
| """ | ||
| super().__init__(secret) | ||
| self._account_id = account_id | ||
| self._min_remaining_validity_seconds = min_remaining_validity_seconds | ||
|
|
||
| @classmethod | ||
| def clear_cache(cls) -> None: | ||
| """Clear all cached account tokens and refresh locks.""" | ||
| cls._token_cache.clear() | ||
| cls._sync_locks.clear() | ||
| cls._async_locks.clear() | ||
|
|
||
| @override | ||
| def sync_auth_flow( | ||
| self, request: httpx.Request | ||
| ) -> Generator[httpx.Request, httpx.Response, None]: | ||
| """Attach an account-scoped token, refreshing it proactively and on 401.""" | ||
| token = self._token_sync() | ||
| request.headers["Authorization"] = f"Bearer {token}" | ||
| response = yield request | ||
| if response.status_code == httpx.codes.UNAUTHORIZED: | ||
| rejected = request.headers["Authorization"].removeprefix("Bearer ") | ||
| request.headers["Authorization"] = f"Bearer {self._token_sync(rejected)}" | ||
| yield request | ||
|
|
||
| @override | ||
| async def async_auth_flow( | ||
| self, request: httpx.Request | ||
| ) -> AsyncGenerator[httpx.Request, httpx.Response]: | ||
| """Attach an account-scoped token, refreshing it proactively and on 401.""" | ||
| token = await self._token_async() | ||
| request.headers["Authorization"] = f"Bearer {token}" | ||
| response = yield request | ||
| if response.status_code == httpx.codes.UNAUTHORIZED: | ||
| rejected = request.headers["Authorization"].removeprefix("Bearer ") | ||
| refreshed = await self._token_async(rejected) | ||
| request.headers["Authorization"] = f"Bearer {refreshed}" | ||
| yield request | ||
|
|
||
| @property | ||
| def _cache_key(self) -> CacheKey: | ||
| """Return the shared-cache key for this provider's scope.""" | ||
| return self._secret, self._account_id | ||
|
|
||
| def _token_sync(self, rejected: str | None = None) -> str: | ||
| """Return a usable token, fetching one under a per-account lock when needed.""" | ||
| cached = self._token_cache.get(self._cache_key) | ||
| if self._is_usable(cached, rejected): | ||
| return cached.token # type: ignore[union-attr] | ||
|
|
||
| lock = self._sync_locks.setdefault(self._cache_key, threading.Lock()) | ||
| with lock: | ||
| cached = self._token_cache.get(self._cache_key) | ||
| if self._is_usable(cached, rejected): | ||
| return cached.token # type: ignore[union-attr] | ||
| fetched = self._get_sync_service().token(self._account_id) | ||
| return self._store(fetched.token) | ||
|
|
||
| async def _token_async(self, rejected: str | None = None) -> str: | ||
| """Return a usable token, fetching one under a per-account lock when needed.""" | ||
| cached = self._token_cache.get(self._cache_key) | ||
| if self._is_usable(cached, rejected): | ||
| return cached.token # type: ignore[union-attr] | ||
|
|
||
| lock = self._async_locks.setdefault(self._cache_key, asyncio.Lock()) | ||
| async with lock: | ||
| cached = self._token_cache.get(self._cache_key) | ||
| if self._is_usable(cached, rejected): | ||
| return cached.token # type: ignore[union-attr] | ||
| fetched = await self._get_async_service().token(self._account_id) | ||
| return self._store(fetched.token) | ||
|
|
||
| def _is_usable(self, cached: _CachedToken | None, rejected: str | None) -> bool: | ||
| """Return whether the cached token can be reused for the current request. | ||
|
|
||
| A token is unusable when it is missing, when it equals a token the server just | ||
| rejected, or when it is within the proactive refresh leeway of its expiry. Tokens | ||
| without a readable ``exp`` are reused and rely on the reactive ``401`` path. | ||
| """ | ||
| if cached is None or cached.token == rejected: | ||
| return False | ||
| if cached.expires_at is None: | ||
| return True | ||
| threshold = dt.datetime.now(dt.UTC).timestamp() + self._min_remaining_validity_seconds | ||
| return cached.expires_at.timestamp() > threshold | ||
|
|
||
| def _store(self, token: str | None) -> str: | ||
| """Cache a freshly fetched token, evicting expired entries, and return it.""" | ||
| if not token: | ||
| raise MPTError("Installations token endpoint returned an empty token.") | ||
| self._token_cache[self._cache_key] = _CachedToken(token, self._read_expiry(token)) | ||
| self._evict_expired() | ||
| return token | ||
|
|
||
| def _evict_expired(self) -> None: | ||
| """Drop cache entries (and their locks) whose tokens have already expired.""" | ||
| now = dt.datetime.now(dt.UTC) | ||
| expired_keys = [ | ||
| key | ||
| for key, cached in self._token_cache.items() | ||
| if cached.expires_at is not None and cached.expires_at <= now | ||
| ] | ||
| for key in expired_keys: | ||
| self._token_cache.pop(key, None) | ||
| self._sync_locks.pop(key, None) | ||
| self._async_locks.pop(key, None) | ||
|
Comment on lines
+152
to
+171
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π§© Analysis chainπ Script executed: #!/bin/bash
# First, locate and read the AccountScopedAuthentication class
find . -name "account_scoped.py" -type f | head -20Repository: softwareone-platform/mpt-api-python-client Length of output: 126 π Script executed: #!/bin/bash
# Read the account_scoped.py file to understand the implementation
cat -n mpt_api_client/auth/account_scoped.pyRepository: softwareone-platform/mpt-api-python-client Length of output: 8785 π Script executed: #!/bin/bash
# Search for all references to _token_cache, _sync_locks, _async_locks across the codebase
rg "(_token_cache|_sync_locks|_async_locks)" -n --context 2Repository: softwareone-platform/mpt-api-python-client Length of output: 4855 π Script executed: #!/bin/bash
# Look for threading patterns and Lock usage in the auth module
rg "threading\.|Lock\(\)|asyncio\.Lock\(\)" mpt_api_client/auth/ -n --context 1Repository: softwareone-platform/mpt-api-python-client Length of output: 855 π Script executed: #!/bin/bash
# Check if the _cache_guard fix has been applied
grep -n "_cache_guard" mpt_api_client/auth/account_scoped.pyRepository: softwareone-platform/mpt-api-python-client Length of output: 68 π Script executed: #!/bin/bash
# Look for any test coverage of concurrent account access
grep -r "concurrent\|thread\|async.*test" tests/unit/auth/test_account_scoped.py -iRepository: softwareone-platform/mpt-api-python-client Length of output: 217 π Script executed: #!/bin/bash
# Check what the concurrency test actually tests
sed -n '210,280p' tests/unit/auth/test_account_scoped.pyRepository: softwareone-platform/mpt-api-python-client Length of output: 2213 π Script executed: #!/bin/bash
# Check if there are any tests mixing different accounts with concurrent access
grep -A 30 "test_serialized_refresh" tests/unit/auth/test_account_scoped.pyRepository: softwareone-platform/mpt-api-python-client Length of output: 1320 π Script executed: #!/bin/bash
# Check if there are any tests that exercise multiple accounts concurrently
grep -B 5 -A 15 "ACC-OLD\|ACC-NEW" tests/unit/auth/test_account_scoped.py | head -40Repository: softwareone-platform/mpt-api-python-client Length of output: 715 π Script executed: #!/bin/bash
# Look for documentation or usage patterns showing multiple accounts
grep -r "AccountScopedAuthentication" --include="*.py" --include="*.md" | grep -v "test_" | grep -v ".pyc"Repository: softwareone-platform/mpt-api-python-client Length of output: 752 π Script executed: #!/bin/bash
# Check if the concurrency test uses multiple accounts
grep -B 5 -A 20 "test_serialized_refresh_fetches_token_once_under_concurrency" tests/unit/auth/test_account_scoped.pyRepository: softwareone-platform/mpt-api-python-client Length of output: 1065 π Script executed: #!/bin/bash
# Look at the full test file structure to understand concurrency coverage
grep "^def test_\|^async def test_" tests/unit/auth/test_account_scoped.pyRepository: softwareone-platform/mpt-api-python-client Length of output: 670 π Script executed: #!/bin/bash
# Check the full test_different_accounts_fetch_separate_tokens to see if it's concurrent
grep -B 5 -A 25 "test_different_accounts_fetch_separate_tokens" tests/unit/auth/test_account_scoped.pyRepository: softwareone-platform/mpt-api-python-client Length of output: 1136 π Script executed: #!/bin/bash
# Run the test suite to verify the current state of the code
# First, let's check if there's a way to run tests and see if they would catch this race
grep -r "pytest\|test" docs/testing.md | head -10Repository: softwareone-platform/mpt-api-python-client Length of output: 715 Serialize global cache eviction to prevent concurrent-dict runtime failures. Line 165 iterates over The test Proposed fix class AccountScopedAuthentication(InstallationTokenAuthentication): # noqa: WPS214
@@
_token_cache: ClassVar[dict[CacheKey, _CachedToken]] = {}
_sync_locks: ClassVar[dict[CacheKey, threading.Lock]] = {}
_async_locks: ClassVar[dict[CacheKey, asyncio.Lock]] = {}
+ _cache_guard: ClassVar[threading.Lock] = threading.Lock()
@@
def clear_cache(cls) -> None:
"""Clear all cached account tokens and refresh locks."""
- cls._token_cache.clear()
- cls._sync_locks.clear()
- cls._async_locks.clear()
+ with cls._cache_guard:
+ cls._token_cache.clear()
+ cls._sync_locks.clear()
+ cls._async_locks.clear()
@@
def _store(self, token: str | None) -> str:
"""Cache a freshly fetched token, evicting expired entries, and return it."""
if not token:
raise MPTError("Installations token endpoint returned an empty token.")
- self._token_cache[self._cache_key] = _CachedToken(token, self._read_expiry(token))
- self._evict_expired()
+ with self._cache_guard:
+ self._token_cache[self._cache_key] = _CachedToken(token, self._read_expiry(token))
+ self._evict_expired()
return tokenπ€ Prompt for AI Agents |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Validate non-negative refresh leeway at construction.
min_remaining_validity_secondsis used directly in expiry threshold math (Line 149). If a negative value is passed, expired tokens can be treated as reusable. Add a constructor guard so invalid input fails fast.Proposed fix
def __init__( self, secret: str, account_id: str, min_remaining_validity_seconds: int = DEFAULT_TOKEN_VALIDITY_LEEWAY_SECONDS, ) -> None: @@ """ super().__init__(secret) + if min_remaining_validity_seconds < 0: + raise ValueError("min_remaining_validity_seconds must be >= 0") self._account_id = account_id self._min_remaining_validity_seconds = min_remaining_validity_secondsAlso applies to: 149-150
π€ Prompt for AI Agents