diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 9d3bee5c6..06931041e 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -113,10 +113,6 @@ jobs: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} - GATEWAY_AUTH_ISSUER: ${{ secrets.GATEWAY_AUTH_ISSUER }} - GATEWAY_AUTH_AUDIENCE: ${{ secrets.GATEWAY_AUTH_AUDIENCE }} - GATEWAY_AUTH_CLIENT_ID: ${{ secrets.GATEWAY_AUTH_CLIENT_ID }} - GATEWAY_AUTH_CLIENT_SECRET: ${{ secrets.GATEWAY_AUTH_CLIENT_SECRET }} docker: name: Docker runs-on: ubuntu-latest diff --git a/gcp/export.py b/gcp/export.py index 4515a98d5..b4eaafc00 100644 --- a/gcp/export.py +++ b/gcp/export.py @@ -14,10 +14,6 @@ ANTHROPIC_API_KEY = os.environ["ANTHROPIC_API_KEY"] OPENAI_API_KEY = os.environ["OPENAI_API_KEY"] HUGGING_FACE_TOKEN = os.environ["HUGGING_FACE_TOKEN"] -GATEWAY_AUTH_ISSUER = os.environ["GATEWAY_AUTH_ISSUER"] -GATEWAY_AUTH_AUDIENCE = os.environ["GATEWAY_AUTH_AUDIENCE"] -GATEWAY_AUTH_CLIENT_ID = os.environ["GATEWAY_AUTH_CLIENT_ID"] -GATEWAY_AUTH_CLIENT_SECRET = os.environ["GATEWAY_AUTH_CLIENT_SECRET"] # Export GAE to to .gac.json and DB_PD to .dbpw in the current directory @@ -39,14 +35,6 @@ dockerfile = dockerfile.replace(".anthropic_api_key", ANTHROPIC_API_KEY) dockerfile = dockerfile.replace(".openai_api_key", OPENAI_API_KEY) dockerfile = dockerfile.replace(".hugging_face_token", HUGGING_FACE_TOKEN) - dockerfile = dockerfile.replace(".gateway_auth_issuer", GATEWAY_AUTH_ISSUER) - dockerfile = dockerfile.replace(".gateway_auth_audience", GATEWAY_AUTH_AUDIENCE) - dockerfile = dockerfile.replace( - ".gateway_auth_client_id", GATEWAY_AUTH_CLIENT_ID - ) - dockerfile = dockerfile.replace( - ".gateway_auth_client_secret", GATEWAY_AUTH_CLIENT_SECRET - ) with open(dockerfile_location, "w") as f: f.write(dockerfile) diff --git a/gcp/policyengine_api/Dockerfile b/gcp/policyengine_api/Dockerfile index c040cbc92..9b8d4e174 100644 --- a/gcp/policyengine_api/Dockerfile +++ b/gcp/policyengine_api/Dockerfile @@ -9,10 +9,6 @@ ENV ANTHROPIC_API_KEY .anthropic_api_key ENV OPENAI_API_KEY .openai_api_key ENV HUGGING_FACE_TOKEN .hugging_face_token ENV CREDENTIALS_JSON_API_V2 .credentials_json_api_v2 -ENV GATEWAY_AUTH_ISSUER .gateway_auth_issuer -ENV GATEWAY_AUTH_AUDIENCE .gateway_auth_audience -ENV GATEWAY_AUTH_CLIENT_ID .gateway_auth_client_id -ENV GATEWAY_AUTH_CLIENT_SECRET .gateway_auth_client_secret WORKDIR /app diff --git a/policyengine_api/libs/gateway_auth.py b/policyengine_api/libs/gateway_auth.py deleted file mode 100644 index 8e4b1eb30..000000000 --- a/policyengine_api/libs/gateway_auth.py +++ /dev/null @@ -1,191 +0,0 @@ -"""Auth0 client_credentials support for outbound calls to the simulation gateway. - -The simulation API gateway (``policyengine-api-v2``) gates its write and -job-status endpoints behind a bearer JWT minted by the PolicyEngine Auth0 -tenant. This module fetches that token for the v1 API process, caches it in -memory, and attaches it to every outbound HTTP call via an ``httpx.Auth`` -implementation. -""" - -from __future__ import annotations - -import os -import threading -import time -from typing import Optional - -import httpx - - -GATEWAY_AUTH_ISSUER_ENV = "GATEWAY_AUTH_ISSUER" -GATEWAY_AUTH_AUDIENCE_ENV = "GATEWAY_AUTH_AUDIENCE" -GATEWAY_AUTH_CLIENT_ID_ENV = "GATEWAY_AUTH_CLIENT_ID" -GATEWAY_AUTH_CLIENT_SECRET_ENV = "GATEWAY_AUTH_CLIENT_SECRET" - -GATEWAY_AUTH_ENV_VARS = ( - GATEWAY_AUTH_ISSUER_ENV, - GATEWAY_AUTH_AUDIENCE_ENV, - GATEWAY_AUTH_CLIENT_ID_ENV, - GATEWAY_AUTH_CLIENT_SECRET_ENV, -) - - -class GatewayAuthError(RuntimeError): - """Raised when the gateway auth config is missing or the token fetch fails.""" - - -def _require_all_or_none_gateway_auth_env() -> None: - """Refuse to start when the four GATEWAY_AUTH_* env vars are partially set. - - A typo in one GH Action secret name would otherwise silently degrade to - unauthenticated gateway calls, which is the exact scenario this module - exists to prevent. - """ - present = [name for name in GATEWAY_AUTH_ENV_VARS if os.environ.get(name)] - if present and len(present) != len(GATEWAY_AUTH_ENV_VARS): - missing = [name for name in GATEWAY_AUTH_ENV_VARS if not os.environ.get(name)] - raise GatewayAuthError( - "Gateway auth is partially configured: " - f"{', '.join(present)} set but {', '.join(missing)} missing. " - "Set all four or none." - ) - - -class GatewayAuthTokenProvider: - """Fetch and cache an Auth0 ``client_credentials`` access token. - - The provider is thread-safe and refreshes the token a short window before - its advertised expiry so concurrent workers cannot race and observe an - expired token between validation and use. A single instance can (and - should) be shared across many HTTP clients. - """ - - _REFRESH_MARGIN_SECONDS = 60 - - def __init__( - self, - issuer: Optional[str] = None, - audience: Optional[str] = None, - client_id: Optional[str] = None, - client_secret: Optional[str] = None, - *, - http_timeout: float = 10.0, - ): - self._issuer = ( - issuer - if issuer is not None - else os.environ.get(GATEWAY_AUTH_ISSUER_ENV, "") - ).rstrip("/") - self._audience = ( - audience - if audience is not None - else os.environ.get(GATEWAY_AUTH_AUDIENCE_ENV, "") - ) - self._client_id = ( - client_id - if client_id is not None - else os.environ.get(GATEWAY_AUTH_CLIENT_ID_ENV, "") - ) - self._client_secret = ( - client_secret - if client_secret is not None - else os.environ.get(GATEWAY_AUTH_CLIENT_SECRET_ENV, "") - ) - self._http_timeout = http_timeout - self._token: Optional[str] = None - self._expires_at: float = 0.0 - self._lock = threading.Lock() - - @property - def configured(self) -> bool: - """True iff all four required env vars / kwargs were provided.""" - return all((self._issuer, self._audience, self._client_id, self._client_secret)) - - def get_token(self) -> str: - """Return a valid bearer token, fetching or refreshing as needed.""" - if not self.configured: - raise GatewayAuthError( - "Gateway auth not configured: set " - f"{GATEWAY_AUTH_ISSUER_ENV}, {GATEWAY_AUTH_AUDIENCE_ENV}, " - f"{GATEWAY_AUTH_CLIENT_ID_ENV}, and " - f"{GATEWAY_AUTH_CLIENT_SECRET_ENV}." - ) - with self._lock: - now = time.time() - if ( - self._token is None - or now >= self._expires_at - self._REFRESH_MARGIN_SECONDS - ): - self._fetch_locked() - return self._token # type: ignore[return-value] - - def _fetch_locked(self) -> None: - """Call Auth0's ``/oauth/token``. Caller must hold ``_lock``.""" - try: - response = httpx.post( - f"{self._issuer}/oauth/token", - json={ - "client_id": self._client_id, - "client_secret": self._client_secret, - "audience": self._audience, - "grant_type": "client_credentials", - }, - timeout=self._http_timeout, - ) - except httpx.RequestError as exc: - raise GatewayAuthError(f"Auth0 token fetch network error: {exc}") from exc - - try: - response.raise_for_status() - except httpx.HTTPStatusError as exc: - raise GatewayAuthError( - f"Auth0 token fetch failed: HTTP {response.status_code}" - ) from exc - - data = response.json() - token = data.get("access_token") - if not token: - raise GatewayAuthError("Auth0 response missing access_token") - # Clamp expires_in so a pathological short/zero value from Auth0 - # cannot drive the refresh check into perpetual refetching under - # concurrent load. - raw_expires_in = data.get("expires_in") - if raw_expires_in is None: - raise GatewayAuthError("Auth0 response missing expires_in") - expires_in = max(int(raw_expires_in), self._REFRESH_MARGIN_SECONDS * 2) - self._token = token - self._expires_at = time.time() + expires_in - - def invalidate(self) -> None: - """Drop the cached token so the next ``get_token`` call refetches. - - Intended for recovery after a 401 from the gateway (e.g. the Auth0 - signing key rotated) rather than routine use. - """ - with self._lock: - self._token = None - self._expires_at = 0.0 - - -class GatewayBearerAuth(httpx.Auth): - """``httpx.Auth`` adapter that attaches a refreshed bearer token per request. - - Implements httpx's two-yield retry contract: on a 401 the cached token is - invalidated and a single retry is made with a freshly fetched token. This - covers the common case of Auth0 rotating its JWKS while a long-lived v1 - worker holds a stale token. - """ - - def __init__(self, token_provider: GatewayAuthTokenProvider): - self._token_provider = token_provider - - def auth_flow(self, request): - request.headers["Authorization"] = f"Bearer {self._token_provider.get_token()}" - response = yield request - - if response.status_code != 401: - return - - self._token_provider.invalidate() - request.headers["Authorization"] = f"Bearer {self._token_provider.get_token()}" - yield request diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index 2492e583e..4cf0b1423 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -12,11 +12,6 @@ import httpx from policyengine_api.gcp_logging import logger -from policyengine_api.libs.gateway_auth import ( - GatewayAuthTokenProvider, - GatewayBearerAuth, - _require_all_or_none_gateway_auth_env, -) @dataclass @@ -52,24 +47,7 @@ def __init__(self): "SIMULATION_API_URL", "https://policyengine--policyengine-simulation-gateway-web-app.modal.run", ) - self._token_provider = GatewayAuthTokenProvider() - _require_all_or_none_gateway_auth_env() - auth = ( - GatewayBearerAuth(self._token_provider) - if self._token_provider.configured - else None - ) - if auth is None: - logger.log_struct( - { - "message": ( - "SimulationAPIModal initialised without gateway auth; " - "all GATEWAY_AUTH_* env vars are unset." - ), - }, - severity="WARNING", - ) - self.client = httpx.Client(timeout=30.0, auth=auth) + self.client = httpx.Client(timeout=30.0) def run(self, payload: dict) -> ModalSimulationExecution: """ diff --git a/tests/unit/libs/test_gateway_auth.py b/tests/unit/libs/test_gateway_auth.py deleted file mode 100644 index d2e382425..000000000 --- a/tests/unit/libs/test_gateway_auth.py +++ /dev/null @@ -1,388 +0,0 @@ -"""Unit tests for :mod:`policyengine_api.libs.gateway_auth`.""" - -from __future__ import annotations - -import threading -import time -from unittest.mock import MagicMock, patch - -import httpx -import pytest - -from policyengine_api.libs.gateway_auth import ( - GATEWAY_AUTH_ENV_VARS, - GatewayAuthError, - GatewayAuthTokenProvider, - GatewayBearerAuth, - _require_all_or_none_gateway_auth_env, -) - - -ISSUER = "https://policyengine.uk.auth0.com" -AUDIENCE = "https://sim-gateway.policyengine.org" -CLIENT_ID = "test-client-id" -CLIENT_SECRET = "test-client-secret" - - -def _make_token_response(token: str, expires_in: int = 86400) -> MagicMock: - response = MagicMock() - response.status_code = 200 - response.raise_for_status = MagicMock() - response.json.return_value = { - "access_token": token, - "expires_in": expires_in, - "token_type": "Bearer", - } - return response - - -class TestGatewayAuthTokenProvider: - class TestConfigured: - def test__given_all_kwargs__then_configured_true(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - assert provider.configured is True - - def test__given_missing_client_secret__then_configured_false(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret="", - ) - - assert provider.configured is False - - def test__given_empty_env__then_configured_false(self): - with patch.dict( - "os.environ", - { - "GATEWAY_AUTH_ISSUER": "", - "GATEWAY_AUTH_AUDIENCE": "", - "GATEWAY_AUTH_CLIENT_ID": "", - "GATEWAY_AUTH_CLIENT_SECRET": "", - }, - clear=False, - ): - provider = GatewayAuthTokenProvider() - - assert provider.configured is False - - class TestGetToken: - def test__given_unconfigured_provider__then_raises(self): - provider = GatewayAuthTokenProvider( - issuer="", audience="", client_id="", client_secret="" - ) - - with pytest.raises(GatewayAuthError): - provider.get_token() - - def test__given_first_call__then_fetches_and_returns_token(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - return_value=_make_token_response("tok-1"), - ) as mock_post: - token = provider.get_token() - - assert token == "tok-1" - mock_post.assert_called_once() - call_kwargs = mock_post.call_args.kwargs - assert call_kwargs["json"]["grant_type"] == "client_credentials" - assert call_kwargs["json"]["client_id"] == CLIENT_ID - assert call_kwargs["json"]["client_secret"] == CLIENT_SECRET - assert call_kwargs["json"]["audience"] == AUDIENCE - assert call_kwargs["timeout"] == 10.0 - positional_url = mock_post.call_args.args[0] - assert positional_url == f"{ISSUER}/oauth/token" - - def test__given_trailing_slash_issuer__then_no_double_slash(self): - provider = GatewayAuthTokenProvider( - issuer=f"{ISSUER}/", - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - return_value=_make_token_response("tok-1"), - ) as mock_post: - provider.get_token() - - url = mock_post.call_args.args[0] - assert url == f"{ISSUER}/oauth/token" - - def test__given_fresh_cached_token__then_second_call_reuses(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - return_value=_make_token_response("tok-1", expires_in=86400), - ) as mock_post: - first = provider.get_token() - second = provider.get_token() - - assert first == second == "tok-1" - mock_post.assert_called_once() - - def test__given_expired_token__then_second_call_refreshes(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - responses = [ - _make_token_response("tok-1", expires_in=60), - _make_token_response("tok-2", expires_in=86400), - ] - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - side_effect=responses, - ) as mock_post: - first = provider.get_token() - # Simulate wall-clock advancing past the refresh margin. - provider._expires_at = time.time() - 1 - second = provider.get_token() - - assert first == "tok-1" - assert second == "tok-2" - assert mock_post.call_count == 2 - - def test__given_auth0_http_error__then_raises_gateway_auth_error(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - mock_response = MagicMock() - mock_response.status_code = 401 - mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "unauthorized", - request=MagicMock(), - response=mock_response, - ) - - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - return_value=mock_response, - ): - with pytest.raises(GatewayAuthError): - provider.get_token() - - def test__given_response_without_access_token__then_raises(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - response = MagicMock() - response.status_code = 200 - response.raise_for_status = MagicMock() - response.json.return_value = {"token_type": "Bearer"} - - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - return_value=response, - ): - with pytest.raises(GatewayAuthError): - provider.get_token() - - class TestInvalidate: - def test__given_invalidated_token__then_next_call_refetches(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - responses = [ - _make_token_response("tok-1"), - _make_token_response("tok-2"), - ] - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - side_effect=responses, - ) as mock_post: - provider.get_token() - provider.invalidate() - provider.get_token() - - assert mock_post.call_count == 2 - - class TestFailureModes: - def test__given_network_error__then_raises_gateway_auth_error(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - side_effect=httpx.ConnectError("boom"), - ): - with pytest.raises(GatewayAuthError): - provider.get_token() - - def test__given_missing_expires_in__then_raises(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - response = MagicMock() - response.status_code = 200 - response.raise_for_status = MagicMock() - response.json.return_value = {"access_token": "tok"} - - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - return_value=response, - ): - with pytest.raises(GatewayAuthError): - provider.get_token() - - def test__given_zero_expires_in__then_clamped_to_refresh_margin(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - return_value=_make_token_response("tok", expires_in=0), - ): - provider.get_token() - - ttl = provider._expires_at - time.time() - assert ttl > provider._REFRESH_MARGIN_SECONDS - - class TestThreadSafety: - def test__given_concurrent_callers__then_fetches_once(self): - provider = GatewayAuthTokenProvider( - issuer=ISSUER, - audience=AUDIENCE, - client_id=CLIENT_ID, - client_secret=CLIENT_SECRET, - ) - - def slow_response(*_args, **_kwargs): - time.sleep(0.05) - return _make_token_response("tok-concurrent") - - with patch( - "policyengine_api.libs.gateway_auth.httpx.post", - side_effect=slow_response, - ) as mock_post: - tokens: list[str] = [] - threads = [ - threading.Thread(target=lambda: tokens.append(provider.get_token())) - for _ in range(20) - ] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - - assert mock_post.call_count == 1 - assert tokens == ["tok-concurrent"] * 20 - - -class TestRequireAllOrNoneGatewayAuthEnv: - def test__given_no_env__then_ok(self, monkeypatch): - for name in GATEWAY_AUTH_ENV_VARS: - monkeypatch.delenv(name, raising=False) - - _require_all_or_none_gateway_auth_env() - - def test__given_all_env__then_ok(self, monkeypatch): - for name in GATEWAY_AUTH_ENV_VARS: - monkeypatch.setenv(name, "x") - - _require_all_or_none_gateway_auth_env() - - def test__given_partial_env__then_raises(self, monkeypatch): - monkeypatch.setenv("GATEWAY_AUTH_ISSUER", "https://tenant.auth0.com") - monkeypatch.setenv("GATEWAY_AUTH_AUDIENCE", "aud") - monkeypatch.delenv("GATEWAY_AUTH_CLIENT_ID", raising=False) - monkeypatch.delenv("GATEWAY_AUTH_CLIENT_SECRET", raising=False) - - with pytest.raises(GatewayAuthError): - _require_all_or_none_gateway_auth_env() - - -class TestGatewayBearerAuthRetry: - def test__given_401_response__then_invalidates_and_retries_with_fresh_token(self): - provider = MagicMock() - provider.get_token.side_effect = ["stale-token", "fresh-token"] - auth = GatewayBearerAuth(provider) - - request = httpx.Request("GET", "https://example.invalid/jobs/abc") - flow = auth.auth_flow(request) - - first_request = next(flow) - assert first_request.headers["Authorization"] == "Bearer stale-token" - - unauthorized = httpx.Response(401, request=first_request) - retry_request = flow.send(unauthorized) - - assert retry_request.headers["Authorization"] == "Bearer fresh-token" - provider.invalidate.assert_called_once() - with pytest.raises(StopIteration): - flow.send(httpx.Response(200, request=retry_request)) - - def test__given_2xx_response__then_no_retry(self): - provider = MagicMock() - provider.get_token.return_value = "tok" - auth = GatewayBearerAuth(provider) - - request = httpx.Request("GET", "https://example.invalid/jobs/abc") - flow = auth.auth_flow(request) - - next(flow) - with pytest.raises(StopIteration): - flow.send(httpx.Response(200, request=request)) - - provider.invalidate.assert_not_called() - - -class TestGatewayBearerAuth: - def test__given_request__then_attaches_bearer_token_header(self): - provider = MagicMock() - provider.get_token.return_value = "tok-xyz" - auth = GatewayBearerAuth(provider) - - request = httpx.Request("GET", "https://example.invalid/") - flow = auth.auth_flow(request) - next(flow) - - assert request.headers["Authorization"] == "Bearer tok-xyz" - provider.get_token.assert_called_once() diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index 0ee3e0248..10b278c82 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -120,53 +120,6 @@ def test__given_env_var_not_set__then_uses_default_url(self, mock_httpx_client): assert "policyengine-simulation-gateway" in api.base_url assert "modal.run" in api.base_url - def test__given_gateway_auth_env_vars__then_attaches_bearer_auth( - self, mock_httpx_client, monkeypatch - ): - from policyengine_api.libs.simulation_api_modal import httpx as modal_httpx - from policyengine_api.libs.gateway_auth import GatewayBearerAuth - - monkeypatch.setenv("GATEWAY_AUTH_ISSUER", "https://tenant.auth0.com") - monkeypatch.setenv("GATEWAY_AUTH_AUDIENCE", "https://sim-gateway") - monkeypatch.setenv("GATEWAY_AUTH_CLIENT_ID", "id") - monkeypatch.setenv("GATEWAY_AUTH_CLIENT_SECRET", "secret") - - SimulationAPIModal() - - _, kwargs = modal_httpx.Client.call_args - assert isinstance(kwargs.get("auth"), GatewayBearerAuth) - - def test__given_missing_gateway_auth_env_vars__then_no_auth_attached( - self, mock_httpx_client, monkeypatch - ): - from policyengine_api.libs.simulation_api_modal import httpx as modal_httpx - - for key in ( - "GATEWAY_AUTH_ISSUER", - "GATEWAY_AUTH_AUDIENCE", - "GATEWAY_AUTH_CLIENT_ID", - "GATEWAY_AUTH_CLIENT_SECRET", - ): - monkeypatch.delenv(key, raising=False) - - SimulationAPIModal() - - _, kwargs = modal_httpx.Client.call_args - assert kwargs.get("auth") is None - - def test__given_partial_gateway_auth_env_vars__then_raises( - self, mock_httpx_client, monkeypatch - ): - from policyengine_api.libs.gateway_auth import GatewayAuthError - - monkeypatch.setenv("GATEWAY_AUTH_ISSUER", "https://tenant.auth0.com") - monkeypatch.setenv("GATEWAY_AUTH_AUDIENCE", "aud") - monkeypatch.delenv("GATEWAY_AUTH_CLIENT_ID", raising=False) - monkeypatch.delenv("GATEWAY_AUTH_CLIENT_SECRET", raising=False) - - with pytest.raises(GatewayAuthError): - SimulationAPIModal() - class TestRun: def test__given_valid_payload__then_returns_execution_with_job_id( self,