diff --git a/.github/workflows/regenerate.yml b/.github/workflows/regenerate.yml index e818d7e..46cd5fe 100644 --- a/.github/workflows/regenerate.yml +++ b/.github/workflows/regenerate.yml @@ -113,6 +113,86 @@ jobs: - name: Patch ApiClient close lifecycle run: python3 scripts/patch_api_client_close.py + - name: Verify JWT-exchange code survived regeneration + run: | + python3 - <<'PY' + import ast, pathlib, sys + + errors = [] + + # 1. The hand-written, regen-immune auth module must survive. + if not pathlib.Path("hotdata/_auth.py").is_file(): + errors.append("hotdata/_auth.py is missing (regen overwrote/dropped it)") + + config = pathlib.Path("hotdata/configuration.py") + if not config.is_file(): + errors.append("hotdata/configuration.py is missing") + else: + tree = ast.parse(config.read_text()) + cls = next( + (n for n in tree.body + if isinstance(n, ast.ClassDef) and n.name == "Configuration"), + None, + ) + if cls is None: + errors.append("Configuration class not found in configuration.py") + else: + # 2. api_key must be a property (decorated getter), so every + # request transparently exchanges for a fresh JWT. + api_key_is_property = any( + isinstance(n, ast.FunctionDef) + and n.name == "api_key" + and any( + isinstance(d, ast.Name) and d.id == "property" + for d in n.decorator_list + ) + for n in cls.body + ) + if not api_key_is_property: + errors.append("Configuration.api_key is not a @property (template drift)") + + # 3. The token manager must be created eagerly in __init__ + # (lazy creation has a concurrent-first-request race). + init = next( + (n for n in cls.body + if isinstance(n, ast.FunctionDef) and n.name == "__init__"), + None, + ) + init_src = ast.get_source_segment(config.read_text(), init) if init else "" + if "self._token_manager = _TokenManager(" not in (init_src or ""): + errors.append("eager self._token_manager assignment missing from __init__") + + # 4. __deepcopy__ must skip _token_manager (lock + PoolManager + # are not deepcopy-able) and rebuild it. + deepcopy = next( + (n for n in cls.body + if isinstance(n, ast.FunctionDef) and n.name == "__deepcopy__"), + None, + ) + if deepcopy is None: + errors.append("__deepcopy__ missing from Configuration") + else: + # Look for _token_manager as a real identifier/string in the + # body (AST, so comments mentioning it don't count) — proves + # the lock/PoolManager skip-and-rebuild actually survived. + refs = any( + (isinstance(n, ast.Constant) and n.value == "_token_manager") + or (isinstance(n, ast.Attribute) and n.attr == "_token_manager") + for n in ast.walk(deepcopy) + ) + if not refs: + errors.append("__deepcopy__ does not skip/rebuild _token_manager") + + if errors: + print("::error::JWT-exchange regen-safety check failed:") + for e in errors: + print(f" - {e}") + sys.exit(1) + print("JWT-exchange code survived regeneration: " + "_auth.py present, api_key property, eager _token_manager, " + "__deepcopy__ handling all intact.") + PY + - name: Clean up generated artifacts run: | rm -f openapi.yaml diff --git a/.openapi-generator-ignore b/.openapi-generator-ignore index dbd3899..816fbf6 100644 --- a/.openapi-generator-ignore +++ b/.openapi-generator-ignore @@ -2,3 +2,4 @@ git_push.sh README.md setup.py +hotdata/_auth.py diff --git a/.openapi-generator-templates/configuration.mustache b/.openapi-generator-templates/configuration.mustache index acd05b2..0b6ac91 100644 --- a/.openapi-generator-templates/configuration.mustache +++ b/.openapi-generator-templates/configuration.mustache @@ -321,7 +321,12 @@ conf = {{packageName}}.Configuration( self.temp_folder_path = None """Temp file folder for downloading files """ - self.api_key = api_key + # Transparent API-token -> JWT exchange. `api_key` is a property whose + # getter returns a live JWT minted from this credential (see _auth.py); + # the manager is created eagerly here (never lazily in the getter) so + # concurrent first requests don't each build one. The setter rebuilds it. + from {{packageName}}._auth import _TokenManager + self._token_manager = _TokenManager(api_key, self) if api_key is not None else None """Hotdata API key, sent as `Authorization: Bearer `.""" # apiKey-security values (X-Workspace-Id, X-Session-Id), keyed by # scheme name. Read by the generated `auth_settings()` below. @@ -451,13 +456,20 @@ conf = {{packageName}}.Configuration( result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k not in ('logger', 'logger_file_handler'): + # _token_manager holds a threading.Lock and a urllib3 PoolManager, + # neither of which is deepcopy-able; rebuild it below from the + # (deepcopy-safe) credential string instead. + if k not in ('logger', 'logger_file_handler', '_token_manager'): setattr(result, k, copy.deepcopy(v, memo)) # shallow copy of loggers result.logger = copy.copy(self.logger) # use setters to configure loggers result.logger_file = self.logger_file result.debug = self.debug + # rebuild the token manager bound to the copy (never deepcopy lock/pool) + from {{packageName}}._auth import _TokenManager + tm = self._token_manager + result._token_manager = _TokenManager(tm._credential, result) if tm else None return result def __setattr__(self, name: str, value: Any) -> None: @@ -608,6 +620,26 @@ conf = {{packageName}}.Configuration( return None + @property + def api_key(self) -> Optional[str]: + """Live bearer credential, sent as `Authorization: Bearer `. + + Backed by the regeneration-immune `_TokenManager` (see `{{packageName}}._auth`): + an opaque API token is transparently exchanged for a short-lived JWT and + kept fresh, while a credential already shaped like a JWT (or exchange + opted out) is returned unchanged. `auth_settings()` reads this on every + request, so the wire always carries a current token. + """ + # Read the manager once: a concurrent `api_key` reset could otherwise + # set it to None between the check and the `.bearer_value()` call. + tm = self._token_manager + return None if tm is None else tm.bearer_value() + + @api_key.setter + def api_key(self, value: Optional[str]) -> None: + from {{packageName}}._auth import _TokenManager + self._token_manager = _TokenManager(value, self) if value is not None else None + @property def workspace_id(self) -> Optional[str]: """Public id of the target workspace (sent as `X-Workspace-Id`).""" @@ -689,7 +721,11 @@ conf = {{packageName}}.Configuration( } {{/isBasicBasic}} {{#isBasicBearer}} - if self.api_key is not None: + # Resolve the bearer token once: `api_key` is a property that may mint a + # JWT and take the token-manager lock, so a second read would lock twice + # and could race a concurrent `api_key` reset (yielding `Bearer None`). + {{name}}_token = self.api_key + if {{name}}_token is not None: auth['{{name}}'] = { 'type': 'bearer', 'in': 'header', @@ -697,7 +733,7 @@ conf = {{packageName}}.Configuration( 'format': '{{.}}', {{/bearerFormat}} 'key': 'Authorization', - 'value': 'Bearer ' + self.api_key + 'value': 'Bearer ' + {{name}}_token } {{/isBasicBearer}} {{#isHttpSignature}} diff --git a/hotdata/_auth.py b/hotdata/_auth.py new file mode 100644 index 0000000..6c47ba3 --- /dev/null +++ b/hotdata/_auth.py @@ -0,0 +1,233 @@ +"""Transparent API-token -> JWT exchange for the Hotdata Python SDK. + +Hotdata is moving API authentication to short-lived JWTs. Users still configure +the SDK with their long-lived ``hd_`` API token, but every request should carry +a fresh JWT instead. This module is the hand-written, regeneration-immune piece +that makes that happen behind the scenes: :class:`_TokenManager` exchanges the +API token for a JWT at ``POST {host}/v1/auth/jwt`` and keeps it fresh, mirroring +the CLI's ``jwt.rs`` logic so the CLI and SDK behave identically. + +OpenAPI Generator only rewrites the files it generates, so a hand-added module +like this one (precedent: :mod:`hotdata.arrow`) survives regeneration. It is +additionally listed in ``.openapi-generator-ignore`` as belt-and-suspenders. + +Key behaviors: + +* **Pass-through** -- a credential that already looks like a JWT (``eyJ`` + prefix, matching the Gateway's own ``^Bearer eyJ.*`` detection) is returned + unchanged and never exchanged. Every other (opaque) credential is treated as + an API token and exchanged; set ``HOTDATA_DISABLE_JWT_EXCHANGE`` to force a + raw, non-JWT credential through as-is (local/dev setups, rollback). +* **Opt-out** -- if ``HOTDATA_DISABLE_JWT_EXCHANGE`` is set to an affirmative + value (``1``/``true``/``yes``/``on``), the credential is always returned + as-is (hard escape hatch for rollout); ``0``/``false``/empty do not opt out. +* **In-memory cache only** -- no disk writes. The server already de-duplicates + mints (keyed by ``sha256(api_token)``), so per-process caching is sufficient. +* **Thread-safe** -- a :class:`threading.Lock` with single-flight mint covers + the case where a shared ``ApiClient`` is hit from many threads at once. +* **Refresh, then re-mint** -- prefer the refresh token when available; on + refresh failure, re-mint from the held API token (always possible since the + SDK holds it). Matches the CLI. +* **TLS/proxy reuse** -- the exchange call reuses the SDK's configured TLS, + client cert and proxy settings (see :func:`_pool_from_config`) so it behaves + like every other SDK request, with a bounded timeout so a stalled token + endpoint fails fast instead of hanging every call. +""" + +import json +import os +import ssl +import threading +import time +from urllib.parse import urlencode + +import urllib3 + +_LEEWAY = 30 # refresh when <30s of life remains +_TIMEOUT = 30.0 # seconds -- never let a stalled token endpoint hang every request +_CLIENT_ID = "hotdata-python-sdk" + +# Env var that disables exchange entirely. Used as a hard escape hatch during +# the rollout window and for local/dev setups. Only affirmative values opt out +# (see _DISABLE_VALUES) so that ``=0`` / ``=false`` do NOT silently disable it. +_DISABLE_ENV = "HOTDATA_DISABLE_JWT_EXCHANGE" +_DISABLE_VALUES = {"1", "true", "yes", "on"} + +# The SOCKS schemes urllib3 routes through SOCKSProxyManager rather than the +# plain ProxyManager. Mirrors hotdata/rest.py's SUPPORTED_SOCKS_PROXIES. +_SUPPORTED_SOCKS_PROXIES = {"socks5", "socks5h", "socks4", "socks4a"} + + +class TokenExchangeError(Exception): + """Raised when an API token cannot be exchanged for a JWT. + + Surfacing ``invalid_grant`` (expired/revoked API token) here keeps the + failure clear instead of a confusing downstream 401. + """ + + +def _is_socks_proxy_url(url): + # Mirror hotdata/rest.py.is_socks_proxy_url so the exchange pool routes + # SOCKS proxies the same way the generated REST client does. + if url is None: + return False + split_section = url.split("://") + if len(split_section) < 2: + return False + return split_section[0].lower() in _SUPPORTED_SOCKS_PROXIES + + +def _pool_from_config(configuration): + """Build a urllib3 pool manager from the SDK's TLS/proxy configuration. + + Deliberately parallels ``RESTClientObject.__init__`` in + :mod:`hotdata.rest` so the token-exchange call honors the same + ``ssl_ca_cert`` / ``ca_cert_data`` / ``cert_file`` / ``key_file`` / + ``proxy`` / ``verify_ssl`` settings as every other SDK request. We build a + fresh, lightweight pool here rather than reaching into the ``ApiClient``'s + REST client (which the ``Configuration`` does not hold a reference to). + """ + # cert_reqs -- honor verify_ssl exactly as the generated client does. + if configuration.verify_ssl: + cert_reqs = ssl.CERT_REQUIRED + else: + cert_reqs = ssl.CERT_NONE + + pool_args = { + "cert_reqs": cert_reqs, + "ca_certs": configuration.ssl_ca_cert, + "cert_file": configuration.cert_file, + "key_file": configuration.key_file, + "ca_cert_data": configuration.ca_cert_data, + } + # Mirror rest.py's hostname/SNI handling so the exchange call does not + # silently fail for users who customize them (corporate MITM proxies set + # assert_hostname; some gateways require an explicit tls_server_name/SNI). + if configuration.assert_hostname is not None: + pool_args["assert_hostname"] = configuration.assert_hostname + if configuration.tls_server_name: + pool_args["server_hostname"] = configuration.tls_server_name + if configuration.socket_options is not None: + pool_args["socket_options"] = configuration.socket_options + # `retries`/`maxsize` are intentionally not mirrored: the exchange is a + # single bounded-timeout request that fails fast rather than retrying. + + if configuration.proxy: + if _is_socks_proxy_url(configuration.proxy): + from urllib3.contrib.socks import SOCKSProxyManager + pool_args["proxy_url"] = configuration.proxy + pool_args["headers"] = configuration.proxy_headers + return SOCKSProxyManager(**pool_args) + pool_args["proxy_url"] = configuration.proxy + pool_args["proxy_headers"] = configuration.proxy_headers + return urllib3.ProxyManager(**pool_args) + + return urllib3.PoolManager(**pool_args) + + +class _TokenManager: + """Exchanges an API token for short-lived JWTs and keeps them fresh. + + A credential that already looks like a JWT (``eyJ`` prefix) is passed + through unchanged, as is any credential when + ``HOTDATA_DISABLE_JWT_EXCHANGE`` is set; every other (opaque) API token is + exchanged. + """ + + def __init__(self, credential, configuration, pool=None): + self._credential = credential + self._config = configuration # read host + TLS lazily at mint time + self._pool = pool # injected in tests; else built from config TLS + self._lock = threading.Lock() + self._jwt = None + self._exp = 0.0 + self._refresh = None + + @property + def _needs_exchange(self): + # Opt-out wins outright: an affirmative HOTDATA_DISABLE_JWT_EXCHANGE + # (1/true/yes/on) means send the credential as-is, never touching the + # token endpoint. Other values (incl. 0/false/empty) do not opt out. + if os.environ.get(_DISABLE_ENV, "").strip().lower() in _DISABLE_VALUES: + return False + # A compact JWT always starts with "eyJ" (base64 of '{"'), matching the + # Gateway's own ``^Bearer eyJ.*`` detection -- those already are what we + # want on the wire, so pass them through. Everything else is an opaque + # API token to be exchanged. (Hotdata API tokens are bare hex; the + # ``hd_`` prefix seen in docs/comments is cosmetic and not enforced by + # the server, so we must not gate on it.) Use HOTDATA_DISABLE_JWT_EXCHANGE + # to force a raw, non-JWT credential through unchanged (local/dev). + return isinstance(self._credential, str) and not self._credential.startswith("eyJ") + + def bearer_value(self): + """Return a live JWT (exchanging + caching), or the credential as-is. + + Returns the credential unchanged when it is already a JWT or when + exchange is disabled; otherwise returns a cached JWT, refreshing or + re-minting it when it is within ``_LEEWAY`` seconds of expiry. + """ + if not self._needs_exchange: + return self._credential # already a JWT (or opt-out) -> unchanged + with self._lock: + # Fast path: a still-valid cached JWT, no network call. + if self._jwt and time.time() < self._exp - _LEEWAY: + return self._jwt + # Prefer the refresh token; on failure, drop it and re-mint below. + if self._refresh and not self._mint( + {"grant_type": "refresh_token", "refresh_token": self._refresh} + ): + self._refresh = None # refresh failed -> fall through to re-mint + # Re-mint from the held API token if we still lack a fresh JWT. + if not self._jwt or time.time() >= self._exp - _LEEWAY: + self._mint({"grant_type": "api_token", "api_token": self._credential}) + return self._jwt + + def _mint(self, params): + # Returns True on success. The refresh path is best-effort: ANY failure + # -- a non-200, a transport error, or a malformed/missing-token body -- + # returns False so the caller re-mints from the held API token. An + # api_token mint instead raises TokenExchangeError on any failure, since + # there is no further fallback. + params["client_id"] = _CLIENT_ID + is_refresh = params["grant_type"] == "refresh_token" + try: + pool = self._pool or _pool_from_config(self._config) # reuses ssl_ca_cert/cert/proxy + host = self._config.host.rstrip("/") # read host lazily -- may be set post-construct + resp = pool.request( + "POST", + f"{host}/v1/auth/jwt", + body=urlencode(params), + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=_TIMEOUT, + ) + if resp.status != 200: + raise TokenExchangeError( + f"token exchange failed: {resp.status} {resp.data[:200]!r}" + ) + data = json.loads(resp.data) + token = data["access_token"] + expires_in = float(data.get("expires_in", 300)) + except ( + TokenExchangeError, + urllib3.exceptions.HTTPError, + ValueError, + TypeError, + KeyError, + ) as exc: + if is_refresh: + return False # let caller re-mint from the API token + if isinstance(exc, TokenExchangeError): + raise + raise TokenExchangeError(f"token exchange failed: {exc!r}") from exc + self._jwt = token + self._exp = time.time() + expires_in + self._refresh = data.get("refresh_token") or self._refresh + return True + + +__all__ = [ + "TokenExchangeError", + "_TokenManager", + "_CLIENT_ID", + "_pool_from_config", +] diff --git a/hotdata/configuration.py b/hotdata/configuration.py index 2fd18b4..1759006 100644 --- a/hotdata/configuration.py +++ b/hotdata/configuration.py @@ -226,7 +226,12 @@ def __init__( self.temp_folder_path = None """Temp file folder for downloading files """ - self.api_key = api_key + # Transparent API-token -> JWT exchange. `api_key` is a property whose + # getter returns a live JWT minted from this credential (see _auth.py); + # the manager is created eagerly here (never lazily in the getter) so + # concurrent first requests don't each build one. The setter rebuilds it. + from hotdata._auth import _TokenManager + self._token_manager = _TokenManager(api_key, self) if api_key is not None else None """Hotdata API key, sent as `Authorization: Bearer `.""" # apiKey-security values (X-Workspace-Id, X-Session-Id), keyed by # scheme name. Read by the generated `auth_settings()` below. @@ -339,13 +344,20 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> Self: result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k not in ('logger', 'logger_file_handler'): + # _token_manager holds a threading.Lock and a urllib3 PoolManager, + # neither of which is deepcopy-able; rebuild it below from the + # (deepcopy-safe) credential string instead. + if k not in ('logger', 'logger_file_handler', '_token_manager'): setattr(result, k, copy.deepcopy(v, memo)) # shallow copy of loggers result.logger = copy.copy(self.logger) # use setters to configure loggers result.logger_file = self.logger_file result.debug = self.debug + # rebuild the token manager bound to the copy (never deepcopy lock/pool) + from hotdata._auth import _TokenManager + tm = self._token_manager + result._token_manager = _TokenManager(tm._credential, result) if tm else None return result def __setattr__(self, name: str, value: Any) -> None: @@ -490,6 +502,26 @@ def get_api_key_with_prefix(self, identifier: str, alias: Optional[str]=None) -> return None + @property + def api_key(self) -> Optional[str]: + """Live bearer credential, sent as `Authorization: Bearer `. + + Backed by the regeneration-immune `_TokenManager` (see `hotdata._auth`): + an opaque API token is transparently exchanged for a short-lived JWT and + kept fresh, while a credential already shaped like a JWT (or exchange + opted out) is returned unchanged. `auth_settings()` reads this on every + request, so the wire always carries a current token. + """ + # Read the manager once: a concurrent `api_key` reset could otherwise + # set it to None between the check and the `.bearer_value()` call. + tm = self._token_manager + return None if tm is None else tm.bearer_value() + + @api_key.setter + def api_key(self, value: Optional[str]) -> None: + from hotdata._auth import _TokenManager + self._token_manager = _TokenManager(value, self) if value is not None else None + @property def workspace_id(self) -> Optional[str]: """Public id of the target workspace (sent as `X-Workspace-Id`).""" @@ -540,12 +572,16 @@ def auth_settings(self)-> AuthSettings: :return: The Auth Settings information dict. """ auth: AuthSettings = {} - if self.api_key is not None: + # Resolve the bearer token once: `api_key` is a property that may mint a + # JWT and take the token-manager lock, so a second read would lock twice + # and could race a concurrent `api_key` reset (yielding `Bearer None`). + BearerAuth_token = self.api_key + if BearerAuth_token is not None: auth['BearerAuth'] = { 'type': 'bearer', 'in': 'header', 'key': 'Authorization', - 'value': 'Bearer ' + self.api_key + 'value': 'Bearer ' + BearerAuth_token } if 'WorkspaceId' in self.api_keys: auth['WorkspaceId'] = { diff --git a/tests/test_arrow.py b/tests/test_arrow.py index 27b7dda..b67188d 100644 --- a/tests/test_arrow.py +++ b/tests/test_arrow.py @@ -82,6 +82,11 @@ def _install_fake_response( ) -> None: """Replace RESTClientObject.request with a stub that records the call.""" + # The api_key used here ("test-key") is a dummy, not a real token. Disable + # transparent JWT exchange so auth_settings() does not try to mint one + # against a non-existent endpoint when the (stubbed) request is built. + monkeypatch.setenv("HOTDATA_DISABLE_JWT_EXCHANGE", "1") + from hotdata import rest def fake_request( diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..5c181c0 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,616 @@ +"""Unit tests for hotdata._auth._TokenManager. + +These tests exercise the transparent JWT-exchange logic in complete isolation +from the network. Every test injects a *fake pool* via the documented +``_TokenManager(credential, configuration, pool=...)`` parameter, so no real +``urllib3.PoolManager`` is ever built and no socket is ever opened. + +They verify the pinned public contract: + +* first mint -- an opaque (non-JWT) credential POSTs an ``api_token`` grant + to ``/v1/auth/jwt`` (form-encoded, correct Content-Type and + ``client_id``) and returns the minted ``access_token``; +* cache hit -- a second ``bearer_value()`` within TTL does not re-hit the + pool; +* near-expiry -- with < ``_LEEWAY`` seconds of life left and a refresh token + held, a ``refresh_token`` grant is POSTed; +* refresh fail -- a non-200 refresh falls back to a fresh ``api_token`` mint; +* ``eyJ`` pass -- a credential starting with ``eyJ`` is returned unchanged and + the pool is never called; +* exchange error -- a non-200 ``api_token`` mint raises ``TokenExchangeError``; +* opt-out -- ``HOTDATA_DISABLE_JWT_EXCHANGE`` returns the credential as-is; +* concurrency -- N racing threads cause exactly one mint (single-flight); +* deepcopy -- a ``Configuration``/manager round-trips through deepcopy + despite the lock+pool, and the copy still mints. +""" + +from __future__ import annotations + +import copy +import json +import threading +import time +from typing import Any, Dict, List, Optional +from urllib.parse import parse_qs + +import pytest + +from hotdata import Configuration +from hotdata._auth import ( + _CLIENT_ID, + _LEEWAY, + TokenExchangeError, + _TokenManager, + _pool_from_config, +) + + +# -------------------------------------------------------------------------- +# Test doubles +# -------------------------------------------------------------------------- + + +class _FakeResponse: + """Minimal stand-in for a urllib3.HTTPResponse. + + ``_TokenManager._mint`` only reads ``.status`` and ``.data`` (the latter + being raw JSON bytes), so that is all we model. + """ + + def __init__(self, status: int, payload: Any): + self.status = status + if isinstance(payload, (bytes, bytearray)): + self.data = bytes(payload) + else: + self.data = json.dumps(payload).encode() + + +class _FakePool: + """Records every ``request(...)`` call and returns scripted responses. + + Each call pops the next response from ``responses``; if the list is + exhausted the last response is reused (handy for "always succeeds" cases). + A ``pre_request`` hook lets concurrency tests slow the first mint down to + force a thread race. + """ + + def __init__(self, responses: List[_FakeResponse], pre_request=None): + self._responses = list(responses) + self.calls: List[Dict[str, Any]] = [] + self._lock = threading.Lock() + self._pre_request = pre_request + + def request( + self, + method: str, + url: str, + body: Optional[Any] = None, + headers: Optional[Dict[str, str]] = None, + timeout: Optional[float] = None, + ) -> _FakeResponse: + if self._pre_request is not None: + self._pre_request() + with self._lock: + self.calls.append( + { + "method": method, + "url": url, + "body": body, + "headers": dict(headers or {}), + "timeout": timeout, + } + ) + if len(self._responses) > 1: + return self._responses.pop(0) + return self._responses[0] + + +def _config(host: str = "https://api.hotdata.test") -> Configuration: + """Build a Configuration exactly the way a user would.""" + return Configuration(host=host, api_key="hd_unused", workspace_id="ws_test") + + +def _mint_response( + access_token: str = "eyJ.minted.jwt", + *, + refresh_token: Optional[str] = "rt_opaque", + expires_in: int = 300, +) -> _FakeResponse: + payload: Dict[str, Any] = { + "access_token": access_token, + "token_type": "Bearer", + "scope": "permission:read_write", + } + if refresh_token is not None: + payload["refresh_token"] = refresh_token + if expires_in is not None: + payload["expires_in"] = expires_in + return _FakeResponse(200, payload) + + +def _form(body: Any) -> Dict[str, List[str]]: + """Decode an x-www-form-urlencoded request body into a dict.""" + if isinstance(body, (bytes, bytearray)): + body = body.decode() + return parse_qs(body) + + +def _bearer_from(auth: Dict[str, Any]) -> str: + """Pull the ``Authorization: Bearer ...`` value out of auth_settings().""" + for setting in auth.values(): + value = str(setting.get("value", "")) + if value.startswith("Bearer "): + return value + raise AssertionError(f"no Bearer auth setting found in {auth!r}") + + +# -------------------------------------------------------------------------- +# First mint +# -------------------------------------------------------------------------- + + +def test_first_mint_posts_api_token_grant() -> None: + pool = _FakePool([_mint_response(access_token="eyJ.first.jwt")]) + cfg = _config() + mgr = _TokenManager("hd_secret_token", cfg, pool=pool) + + token = mgr.bearer_value() + + assert token == "eyJ.first.jwt" + assert len(pool.calls) == 1 + + call = pool.calls[0] + assert call["method"] == "POST" + assert call["url"] == "https://api.hotdata.test/v1/auth/jwt" + assert call["headers"]["Content-Type"] == "application/x-www-form-urlencoded" + + form = _form(call["body"]) + assert form["grant_type"] == ["api_token"] + assert form["api_token"] == ["hd_secret_token"] + assert form["client_id"] == [_CLIENT_ID] + assert _CLIENT_ID == "hotdata-python-sdk" + # The raw API token must never leak into the URL or headers. + assert "hd_secret_token" not in call["url"] + + +def test_host_read_lazily_and_trailing_slash_stripped() -> None: + """Host is read at mint time (so a late ``config.host = ...`` is honored) + and a trailing slash is trimmed before composing the endpoint.""" + pool = _FakePool([_mint_response()]) + cfg = _config(host="https://placeholder.invalid") + mgr = _TokenManager("hd_secret_token", cfg, pool=pool) + + # Reconfigure host after the manager was constructed. + cfg.host = "https://late.hotdata.test/" + mgr.bearer_value() + + assert pool.calls[0]["url"] == "https://late.hotdata.test/v1/auth/jwt" + + +# -------------------------------------------------------------------------- +# Cache hit +# -------------------------------------------------------------------------- + + +def test_second_call_within_ttl_is_cache_hit() -> None: + pool = _FakePool([_mint_response(access_token="eyJ.cached.jwt", expires_in=300)]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool) + + first = mgr.bearer_value() + second = mgr.bearer_value() + + assert first == second == "eyJ.cached.jwt" + # The cached JWT is reused; the pool is hit exactly once. + assert len(pool.calls) == 1 + + +# -------------------------------------------------------------------------- +# Near-expiry refresh +# -------------------------------------------------------------------------- + + +def test_near_expiry_uses_refresh_token_grant() -> None: + # First mint returns a token expiring inside the leeway window, plus a + # refresh token. The next bearer_value() must refresh rather than re-mint. + short_lived = _mint_response( + access_token="eyJ.short.jwt", + refresh_token="rt_first", + expires_in=_LEEWAY - 5, # already inside the refresh window + ) + refreshed = _mint_response( + access_token="eyJ.refreshed.jwt", + refresh_token="rt_second", + expires_in=300, + ) + pool = _FakePool([short_lived, refreshed]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool) + + assert mgr.bearer_value() == "eyJ.short.jwt" + assert mgr.bearer_value() == "eyJ.refreshed.jwt" + + assert len(pool.calls) == 2 + refresh_form = _form(pool.calls[1]["body"]) + assert refresh_form["grant_type"] == ["refresh_token"] + assert refresh_form["refresh_token"] == ["rt_first"] + assert refresh_form["client_id"] == [_CLIENT_ID] + # A refresh grant must not carry the raw API token. + assert "api_token" not in refresh_form + + +# -------------------------------------------------------------------------- +# Refresh failure -> re-mint +# -------------------------------------------------------------------------- + + +def test_refresh_failure_falls_back_to_api_token_mint() -> None: + short_lived = _mint_response( + access_token="eyJ.short.jwt", + refresh_token="rt_doomed", + expires_in=_LEEWAY - 5, + ) + refresh_fail = _FakeResponse(400, {"error": "invalid_grant"}) + remint = _mint_response(access_token="eyJ.reminted.jwt", expires_in=300) + pool = _FakePool([short_lived, refresh_fail, remint]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool) + + assert mgr.bearer_value() == "eyJ.short.jwt" + # Second call: refresh 400 -> fall back to api_token mint. + assert mgr.bearer_value() == "eyJ.reminted.jwt" + + assert len(pool.calls) == 3 + assert _form(pool.calls[1]["body"])["grant_type"] == ["refresh_token"] + remint_form = _form(pool.calls[2]["body"]) + assert remint_form["grant_type"] == ["api_token"] + assert remint_form["api_token"] == ["hd_secret_token"] + + +# -------------------------------------------------------------------------- +# eyJ pass-through +# -------------------------------------------------------------------------- + + +def test_jwt_credential_is_passed_through_unchanged() -> None: + raw_jwt = "eyJhbGciOiJSUzI1NiJ9.payload.signature" + pool = _FakePool([_mint_response()]) + mgr = _TokenManager(raw_jwt, _config(), pool=pool) + + assert mgr.bearer_value() == raw_jwt + # A credential already shaped like a JWT must never be exchanged. + assert pool.calls == [] + + +# -------------------------------------------------------------------------- +# Exchange error +# -------------------------------------------------------------------------- + + +def test_non_200_api_token_mint_raises_token_exchange_error() -> None: + pool = _FakePool([_FakeResponse(401, {"error": "invalid_grant"})]) + mgr = _TokenManager("hd_bad_token", _config(), pool=pool) + + with pytest.raises(TokenExchangeError): + mgr.bearer_value() + + assert len(pool.calls) == 1 + assert _form(pool.calls[0]["body"])["grant_type"] == ["api_token"] + + +# -------------------------------------------------------------------------- +# Opt-out +# -------------------------------------------------------------------------- + + +def test_opt_out_env_var_returns_credential_unchanged( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("HOTDATA_DISABLE_JWT_EXCHANGE", "1") + pool = _FakePool([_mint_response()]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool) + + assert mgr.bearer_value() == "hd_secret_token" + assert pool.calls == [] + + +@pytest.mark.parametrize("value", ["1", "true", "TRUE", "yes", "on", " on "]) +def test_opt_out_affirmative_values_disable( + monkeypatch: pytest.MonkeyPatch, value: str +) -> None: + monkeypatch.setenv("HOTDATA_DISABLE_JWT_EXCHANGE", value) + pool = _FakePool([_mint_response()]) + mgr = _TokenManager("opaque_token", _config(), pool=pool) + + assert mgr.bearer_value() == "opaque_token" + assert pool.calls == [] + + +@pytest.mark.parametrize("value", ["0", "false", "no", "off", ""]) +def test_opt_out_non_affirmative_values_still_exchange( + monkeypatch: pytest.MonkeyPatch, value: str +) -> None: + """``=0`` / ``=false`` etc. must NOT silently disable exchange -- a footgun + if users set them expecting to *enable* it. Exchange still happens.""" + monkeypatch.setenv("HOTDATA_DISABLE_JWT_EXCHANGE", value) + pool = _FakePool([_mint_response(access_token="eyJ.minted.jwt")]) + mgr = _TokenManager("opaque_token", _config(), pool=pool) + + assert mgr.bearer_value() == "eyJ.minted.jwt" + assert len(pool.calls) == 1 + + +# -------------------------------------------------------------------------- +# Concurrency: single-flight mint +# -------------------------------------------------------------------------- + + +def test_concurrent_callers_trigger_exactly_one_mint() -> None: + n_threads = 16 + start = threading.Barrier(n_threads) + + # Slow the in-flight mint so all threads pile up on the lock and would each + # mint if single-flight were broken. + def slow() -> None: + time.sleep(0.05) + + pool = _FakePool( + [_mint_response(access_token="eyJ.single.jwt", expires_in=300)], + pre_request=slow, + ) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool) + + results: List[str] = [] + results_lock = threading.Lock() + + def worker() -> None: + start.wait() + value = mgr.bearer_value() + with results_lock: + results.append(value) + + threads = [threading.Thread(target=worker) for _ in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(results) == n_threads + assert set(results) == {"eyJ.single.jwt"} + # The decisive assertion: only one network mint happened despite the race. + assert len(pool.calls) == 1 + + +# -------------------------------------------------------------------------- +# Deepcopy round-trip (the lock + pool gotcha) +# -------------------------------------------------------------------------- + + +def test_configuration_deepcopy_round_trip() -> None: + """``copy.deepcopy`` of a Configuration carrying a token manager must not + choke on the manager's lock/pool, and the copy must be an independent + object with its own freshly rebuilt manager. + + The copy's manager is rebuilt from the deepcopy-safe credential string (the + lock + pool are never deep-copied), so we assert on that rather than reading + ``api_key`` -- which would trigger a real network exchange.""" + original = Configuration(api_key="hd_x") + + duplicate = copy.deepcopy(original) + + assert duplicate is not original + # Each Configuration carries its own manager; the copy's was rebuilt, not + # the same object (which would share the original's lock/pool). + assert duplicate._token_manager is not None + assert duplicate._token_manager is not original._token_manager + # The manager is bound to the copy (so it reads the copy's host at mint + # time) and the credential survived the round-trip. + assert duplicate._token_manager._config is duplicate + assert duplicate._token_manager._credential == "hd_x" + + +def test_deepcopied_manager_credential_still_mints() -> None: + """A token manager reconstructed from a deepcopy-safe credential still + produces a working bearer value, and the two managers are distinct.""" + cfg = _config() + pool_a = _FakePool([_mint_response(access_token="eyJ.a.jwt")]) + mgr_a = _TokenManager("hd_secret_token", cfg, pool=pool_a) + + # Mimic the __deepcopy__ contract: rebuild from the credential string + # rather than deep-copying the lock/pool. + pool_b = _FakePool([_mint_response(access_token="eyJ.b.jwt")]) + mgr_b = _TokenManager("hd_secret_token", copy.deepcopy(cfg), pool=pool_b) + + assert mgr_a is not mgr_b + assert mgr_a.bearer_value() == "eyJ.a.jwt" + assert mgr_b.bearer_value() == "eyJ.b.jwt" + assert len(pool_a.calls) == 1 + assert len(pool_b.calls) == 1 + + +# -------------------------------------------------------------------------- +# Opaque (non-JWT) credentials are exchanged -- the prefix is not gated +# -------------------------------------------------------------------------- + + +def test_bare_hex_token_is_exchanged() -> None: + """Hotdata API tokens are bare hex with no ``hd_`` prefix (the prefix in + the docs is cosmetic and not enforced by the server). Any opaque, non-JWT + credential must therefore be exchanged, not passed through.""" + raw = "8a4bfd9cfa6926344f770d6b9a093c2b559dafc4de2a69137acb93e7e9821c7b" + pool = _FakePool([_mint_response(access_token="eyJ.minted.jwt")]) + mgr = _TokenManager(raw, _config(), pool=pool) + + assert mgr.bearer_value() == "eyJ.minted.jwt" + assert len(pool.calls) == 1 + assert _form(pool.calls[0]["body"])["api_token"] == [raw] + + +def test_configuration_exchanges_bare_token_then_opt_out_passes_through( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """End-to-end at the Configuration level: a bare token is exchanged so + ``auth_settings()`` carries the minted JWT; with the opt-out env var set the + raw token is sent unchanged (the arrow-test style dummy-key setup).""" + raw = "8a4bfd9c0bare0token" + + # Exchange path: auth_settings() carries the minted JWT, not the raw token. + pool = _FakePool([_mint_response(access_token="eyJ.live.jwt")]) + cfg = Configuration(host="https://api.hotdata.test", api_key=raw) + cfg._token_manager._pool = pool + assert _bearer_from(cfg.auth_settings()) == "Bearer eyJ.live.jwt" + assert len(pool.calls) == 1 + + # Opt-out path: same raw token, exchange disabled -> sent as-is, no mint. + monkeypatch.setenv("HOTDATA_DISABLE_JWT_EXCHANGE", "1") + pool2 = _FakePool([_mint_response()]) + cfg2 = Configuration(host="https://api.hotdata.test", api_key=raw) + cfg2._token_manager._pool = pool2 + assert _bearer_from(cfg2.auth_settings()) == f"Bearer {raw}" + assert pool2.calls == [] + + +# -------------------------------------------------------------------------- +# Configuration.api_key property + auth_settings() end-to-end +# -------------------------------------------------------------------------- + + +def test_configuration_api_key_property_and_auth_settings_use_jwt() -> None: + """The whole point of the design: ``Configuration.api_key`` returns a live + JWT and ``auth_settings()`` assembles ``Authorization: Bearer `` from + it -- the regen-critical path that must keep working.""" + pool = _FakePool([_mint_response(access_token="eyJ.live.jwt", expires_in=300)]) + cfg = _config() + cfg._token_manager = _TokenManager("hd_secret_token", cfg, pool=pool) + + assert cfg.api_key == "eyJ.live.jwt" + assert _bearer_from(cfg.auth_settings()) == "Bearer eyJ.live.jwt" + # Property getter + auth_settings together mint once and then cache. + assert len(pool.calls) == 1 + + +# -------------------------------------------------------------------------- +# _pool_from_config mirrors rest.py's TLS/SNI handling +# -------------------------------------------------------------------------- + + +def test_pool_from_config_honors_assert_hostname_and_sni() -> None: + """A user who sets ``assert_hostname`` (corporate MITM) or + ``tls_server_name`` (custom SNI) must have those applied to the exchange + pool too, or the token call silently fails while normal calls work.""" + cfg = _config() + cfg.assert_hostname = False + cfg.tls_server_name = "sni.internal.test" + + pool = _pool_from_config(cfg) + kw = pool.connection_pool_kw + + assert kw.get("assert_hostname") is False + assert kw.get("server_hostname") == "sni.internal.test" + + +def test_pool_from_config_omits_hostname_args_when_unset() -> None: + """When the user has not customized them, the args are absent (so urllib3 + uses its defaults) -- mirroring rest.py's conditional adds.""" + pool = _pool_from_config(_config()) + kw = pool.connection_pool_kw + + assert "assert_hostname" not in kw + assert "server_hostname" not in kw + + +def test_pool_from_config_forwards_socket_options() -> None: + """socket_options (e.g. TCP keepalive) the user set for all SDK requests + must also apply to the exchange pool, matching rest.py.""" + import socket + + cfg = _config() + cfg.socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] + pool = _pool_from_config(cfg) + + assert pool.connection_pool_kw.get("socket_options") == [ + (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + ] + + +# -------------------------------------------------------------------------- +# Malformed (non-JSON) success body +# -------------------------------------------------------------------------- + + +def test_non_json_success_body_raises_token_exchange_error() -> None: + """A 200 with a non-JSON body (e.g. a misrouted health page) surfaces as a + clear TokenExchangeError rather than a bare JSONDecodeError.""" + pool = _FakePool([_FakeResponse(200, b"not json")]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool) + + with pytest.raises(TokenExchangeError): + mgr.bearer_value() + + +def test_missing_access_token_raises_token_exchange_error() -> None: + """A 200 with valid JSON but no ``access_token`` (e.g. a misrouted endpoint + returning some other JSON document) must surface as a TokenExchangeError, + not a bare KeyError.""" + pool = _FakePool([_FakeResponse(200, {"token_type": "Bearer"})]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool) + + with pytest.raises(TokenExchangeError): + mgr.bearer_value() + + +# -------------------------------------------------------------------------- +# Refresh that fails by *raising* (not just a non-200) still re-mints +# -------------------------------------------------------------------------- + + +def test_refresh_raising_falls_back_to_api_token_mint() -> None: + """The refresh step is best-effort: if it fails in *any* way -- not just a + non-200, but a malformed/non-JSON body or a transport error -- the manager + must drop the refresh token and re-mint from the held API token rather than + letting the exception escape ``bearer_value()``.""" + short_lived = _mint_response( + access_token="eyJ.short.jwt", + refresh_token="rt_doomed", + expires_in=_LEEWAY - 5, + ) + # Refresh returns 200 but a non-JSON body -> would raise inside _mint. + refresh_garbage = _FakeResponse(200, b"oops") + remint = _mint_response(access_token="eyJ.reminted.jwt", expires_in=300) + pool = _FakePool([short_lived, refresh_garbage, remint]) + mgr = _TokenManager("hd_secret_token", _config(), pool=pool) + + assert mgr.bearer_value() == "eyJ.short.jwt" + # Second call: refresh raises internally -> fall back to api_token mint. + assert mgr.bearer_value() == "eyJ.reminted.jwt" + + assert len(pool.calls) == 3 + assert _form(pool.calls[1]["body"])["grant_type"] == ["refresh_token"] + assert _form(pool.calls[2]["body"])["grant_type"] == ["api_token"] + + +# -------------------------------------------------------------------------- +# auth_settings() reads the token exactly once (no double bearer_value()) +# -------------------------------------------------------------------------- + + +def test_auth_settings_reads_token_once(monkeypatch: pytest.MonkeyPatch) -> None: + """``auth_settings()`` must resolve the bearer token a single time, not + once for the null-check and again for the value -- otherwise it acquires the + manager lock twice per request and a concurrent ``api_key`` reset between the + two reads could yield ``'Bearer ' + None``.""" + pool = _FakePool([_mint_response(access_token="eyJ.once.jwt")]) + cfg = _config() + mgr = _TokenManager("hd_secret_token", cfg, pool=pool) + cfg._token_manager = mgr + + count = {"n": 0} + real = mgr.bearer_value + + def counting() -> str: + count["n"] += 1 + return real() + + monkeypatch.setattr(mgr, "bearer_value", counting) + + auth = cfg.auth_settings() + + assert _bearer_from(auth) == "Bearer eyJ.once.jwt" + assert count["n"] == 1