diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 06931041e..9d3bee5c6 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -113,6 +113,10 @@ jobs: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} + GATEWAY_AUTH_ISSUER: ${{ secrets.GATEWAY_AUTH_ISSUER }} + GATEWAY_AUTH_AUDIENCE: ${{ secrets.GATEWAY_AUTH_AUDIENCE }} + GATEWAY_AUTH_CLIENT_ID: ${{ secrets.GATEWAY_AUTH_CLIENT_ID }} + GATEWAY_AUTH_CLIENT_SECRET: ${{ secrets.GATEWAY_AUTH_CLIENT_SECRET }} docker: name: Docker runs-on: ubuntu-latest diff --git a/changelog.d/add-gateway-auth-client.added.md b/changelog.d/add-gateway-auth-client.added.md new file mode 100644 index 000000000..484145fec --- /dev/null +++ b/changelog.d/add-gateway-auth-client.added.md @@ -0,0 +1 @@ +Outbound auth to the simulation API gateway. `SimulationAPIModal` now attaches an Auth0 `client_credentials` bearer token to every request to the Modal simulation gateway, via a new `GatewayAuthTokenProvider` that caches and refreshes tokens in-process. Configured by four env vars: `GATEWAY_AUTH_ISSUER`, `GATEWAY_AUTH_AUDIENCE`, `GATEWAY_AUTH_CLIENT_ID`, `GATEWAY_AUTH_CLIENT_SECRET`. If any are unset, no auth is attached (preserves local/dev behavior). diff --git a/gcp/export.py b/gcp/export.py index b4eaafc00..4515a98d5 100644 --- a/gcp/export.py +++ b/gcp/export.py @@ -14,6 +14,10 @@ ANTHROPIC_API_KEY = os.environ["ANTHROPIC_API_KEY"] OPENAI_API_KEY = os.environ["OPENAI_API_KEY"] HUGGING_FACE_TOKEN = os.environ["HUGGING_FACE_TOKEN"] +GATEWAY_AUTH_ISSUER = os.environ["GATEWAY_AUTH_ISSUER"] +GATEWAY_AUTH_AUDIENCE = os.environ["GATEWAY_AUTH_AUDIENCE"] +GATEWAY_AUTH_CLIENT_ID = os.environ["GATEWAY_AUTH_CLIENT_ID"] +GATEWAY_AUTH_CLIENT_SECRET = os.environ["GATEWAY_AUTH_CLIENT_SECRET"] # Export GAE to to .gac.json and DB_PD to .dbpw in the current directory @@ -35,6 +39,14 @@ dockerfile = dockerfile.replace(".anthropic_api_key", ANTHROPIC_API_KEY) dockerfile = dockerfile.replace(".openai_api_key", OPENAI_API_KEY) dockerfile = dockerfile.replace(".hugging_face_token", HUGGING_FACE_TOKEN) + dockerfile = dockerfile.replace(".gateway_auth_issuer", GATEWAY_AUTH_ISSUER) + dockerfile = dockerfile.replace(".gateway_auth_audience", GATEWAY_AUTH_AUDIENCE) + dockerfile = dockerfile.replace( + ".gateway_auth_client_id", GATEWAY_AUTH_CLIENT_ID + ) + dockerfile = dockerfile.replace( + ".gateway_auth_client_secret", GATEWAY_AUTH_CLIENT_SECRET + ) with open(dockerfile_location, "w") as f: f.write(dockerfile) diff --git a/gcp/policyengine_api/Dockerfile b/gcp/policyengine_api/Dockerfile index 9b8d4e174..c040cbc92 100644 --- a/gcp/policyengine_api/Dockerfile +++ b/gcp/policyengine_api/Dockerfile @@ -9,6 +9,10 @@ ENV ANTHROPIC_API_KEY .anthropic_api_key ENV OPENAI_API_KEY .openai_api_key ENV HUGGING_FACE_TOKEN .hugging_face_token ENV CREDENTIALS_JSON_API_V2 .credentials_json_api_v2 +ENV GATEWAY_AUTH_ISSUER .gateway_auth_issuer +ENV GATEWAY_AUTH_AUDIENCE .gateway_auth_audience +ENV GATEWAY_AUTH_CLIENT_ID .gateway_auth_client_id +ENV GATEWAY_AUTH_CLIENT_SECRET .gateway_auth_client_secret WORKDIR /app diff --git a/policyengine_api/libs/gateway_auth.py b/policyengine_api/libs/gateway_auth.py new file mode 100644 index 000000000..8e4b1eb30 --- /dev/null +++ b/policyengine_api/libs/gateway_auth.py @@ -0,0 +1,191 @@ +"""Auth0 client_credentials support for outbound calls to the simulation gateway. + +The simulation API gateway (``policyengine-api-v2``) gates its write and +job-status endpoints behind a bearer JWT minted by the PolicyEngine Auth0 +tenant. This module fetches that token for the v1 API process, caches it in +memory, and attaches it to every outbound HTTP call via an ``httpx.Auth`` +implementation. +""" + +from __future__ import annotations + +import os +import threading +import time +from typing import Optional + +import httpx + + +GATEWAY_AUTH_ISSUER_ENV = "GATEWAY_AUTH_ISSUER" +GATEWAY_AUTH_AUDIENCE_ENV = "GATEWAY_AUTH_AUDIENCE" +GATEWAY_AUTH_CLIENT_ID_ENV = "GATEWAY_AUTH_CLIENT_ID" +GATEWAY_AUTH_CLIENT_SECRET_ENV = "GATEWAY_AUTH_CLIENT_SECRET" + +GATEWAY_AUTH_ENV_VARS = ( + GATEWAY_AUTH_ISSUER_ENV, + GATEWAY_AUTH_AUDIENCE_ENV, + GATEWAY_AUTH_CLIENT_ID_ENV, + GATEWAY_AUTH_CLIENT_SECRET_ENV, +) + + +class GatewayAuthError(RuntimeError): + """Raised when the gateway auth config is missing or the token fetch fails.""" + + +def _require_all_or_none_gateway_auth_env() -> None: + """Refuse to start when the four GATEWAY_AUTH_* env vars are partially set. + + A typo in one GH Action secret name would otherwise silently degrade to + unauthenticated gateway calls, which is the exact scenario this module + exists to prevent. + """ + present = [name for name in GATEWAY_AUTH_ENV_VARS if os.environ.get(name)] + if present and len(present) != len(GATEWAY_AUTH_ENV_VARS): + missing = [name for name in GATEWAY_AUTH_ENV_VARS if not os.environ.get(name)] + raise GatewayAuthError( + "Gateway auth is partially configured: " + f"{', '.join(present)} set but {', '.join(missing)} missing. " + "Set all four or none." + ) + + +class GatewayAuthTokenProvider: + """Fetch and cache an Auth0 ``client_credentials`` access token. + + The provider is thread-safe and refreshes the token a short window before + its advertised expiry so concurrent workers cannot race and observe an + expired token between validation and use. A single instance can (and + should) be shared across many HTTP clients. + """ + + _REFRESH_MARGIN_SECONDS = 60 + + def __init__( + self, + issuer: Optional[str] = None, + audience: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + *, + http_timeout: float = 10.0, + ): + self._issuer = ( + issuer + if issuer is not None + else os.environ.get(GATEWAY_AUTH_ISSUER_ENV, "") + ).rstrip("/") + self._audience = ( + audience + if audience is not None + else os.environ.get(GATEWAY_AUTH_AUDIENCE_ENV, "") + ) + self._client_id = ( + client_id + if client_id is not None + else os.environ.get(GATEWAY_AUTH_CLIENT_ID_ENV, "") + ) + self._client_secret = ( + client_secret + if client_secret is not None + else os.environ.get(GATEWAY_AUTH_CLIENT_SECRET_ENV, "") + ) + self._http_timeout = http_timeout + self._token: Optional[str] = None + self._expires_at: float = 0.0 + self._lock = threading.Lock() + + @property + def configured(self) -> bool: + """True iff all four required env vars / kwargs were provided.""" + return all((self._issuer, self._audience, self._client_id, self._client_secret)) + + def get_token(self) -> str: + """Return a valid bearer token, fetching or refreshing as needed.""" + if not self.configured: + raise GatewayAuthError( + "Gateway auth not configured: set " + f"{GATEWAY_AUTH_ISSUER_ENV}, {GATEWAY_AUTH_AUDIENCE_ENV}, " + f"{GATEWAY_AUTH_CLIENT_ID_ENV}, and " + f"{GATEWAY_AUTH_CLIENT_SECRET_ENV}." + ) + with self._lock: + now = time.time() + if ( + self._token is None + or now >= self._expires_at - self._REFRESH_MARGIN_SECONDS + ): + self._fetch_locked() + return self._token # type: ignore[return-value] + + def _fetch_locked(self) -> None: + """Call Auth0's ``/oauth/token``. Caller must hold ``_lock``.""" + try: + response = httpx.post( + f"{self._issuer}/oauth/token", + json={ + "client_id": self._client_id, + "client_secret": self._client_secret, + "audience": self._audience, + "grant_type": "client_credentials", + }, + timeout=self._http_timeout, + ) + except httpx.RequestError as exc: + raise GatewayAuthError(f"Auth0 token fetch network error: {exc}") from exc + + try: + response.raise_for_status() + except httpx.HTTPStatusError as exc: + raise GatewayAuthError( + f"Auth0 token fetch failed: HTTP {response.status_code}" + ) from exc + + data = response.json() + token = data.get("access_token") + if not token: + raise GatewayAuthError("Auth0 response missing access_token") + # Clamp expires_in so a pathological short/zero value from Auth0 + # cannot drive the refresh check into perpetual refetching under + # concurrent load. + raw_expires_in = data.get("expires_in") + if raw_expires_in is None: + raise GatewayAuthError("Auth0 response missing expires_in") + expires_in = max(int(raw_expires_in), self._REFRESH_MARGIN_SECONDS * 2) + self._token = token + self._expires_at = time.time() + expires_in + + def invalidate(self) -> None: + """Drop the cached token so the next ``get_token`` call refetches. + + Intended for recovery after a 401 from the gateway (e.g. the Auth0 + signing key rotated) rather than routine use. + """ + with self._lock: + self._token = None + self._expires_at = 0.0 + + +class GatewayBearerAuth(httpx.Auth): + """``httpx.Auth`` adapter that attaches a refreshed bearer token per request. + + Implements httpx's two-yield retry contract: on a 401 the cached token is + invalidated and a single retry is made with a freshly fetched token. This + covers the common case of Auth0 rotating its JWKS while a long-lived v1 + worker holds a stale token. + """ + + def __init__(self, token_provider: GatewayAuthTokenProvider): + self._token_provider = token_provider + + def auth_flow(self, request): + request.headers["Authorization"] = f"Bearer {self._token_provider.get_token()}" + response = yield request + + if response.status_code != 401: + return + + self._token_provider.invalidate() + request.headers["Authorization"] = f"Bearer {self._token_provider.get_token()}" + yield request diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index 4cf0b1423..2492e583e 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -12,6 +12,11 @@ import httpx from policyengine_api.gcp_logging import logger +from policyengine_api.libs.gateway_auth import ( + GatewayAuthTokenProvider, + GatewayBearerAuth, + _require_all_or_none_gateway_auth_env, +) @dataclass @@ -47,7 +52,24 @@ def __init__(self): "SIMULATION_API_URL", "https://policyengine--policyengine-simulation-gateway-web-app.modal.run", ) - self.client = httpx.Client(timeout=30.0) + self._token_provider = GatewayAuthTokenProvider() + _require_all_or_none_gateway_auth_env() + auth = ( + GatewayBearerAuth(self._token_provider) + if self._token_provider.configured + else None + ) + if auth is None: + logger.log_struct( + { + "message": ( + "SimulationAPIModal initialised without gateway auth; " + "all GATEWAY_AUTH_* env vars are unset." + ), + }, + severity="WARNING", + ) + self.client = httpx.Client(timeout=30.0, auth=auth) def run(self, payload: dict) -> ModalSimulationExecution: """ diff --git a/tests/unit/libs/test_gateway_auth.py b/tests/unit/libs/test_gateway_auth.py new file mode 100644 index 000000000..d2e382425 --- /dev/null +++ b/tests/unit/libs/test_gateway_auth.py @@ -0,0 +1,388 @@ +"""Unit tests for :mod:`policyengine_api.libs.gateway_auth`.""" + +from __future__ import annotations + +import threading +import time +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from policyengine_api.libs.gateway_auth import ( + GATEWAY_AUTH_ENV_VARS, + GatewayAuthError, + GatewayAuthTokenProvider, + GatewayBearerAuth, + _require_all_or_none_gateway_auth_env, +) + + +ISSUER = "https://policyengine.uk.auth0.com" +AUDIENCE = "https://sim-gateway.policyengine.org" +CLIENT_ID = "test-client-id" +CLIENT_SECRET = "test-client-secret" + + +def _make_token_response(token: str, expires_in: int = 86400) -> MagicMock: + response = MagicMock() + response.status_code = 200 + response.raise_for_status = MagicMock() + response.json.return_value = { + "access_token": token, + "expires_in": expires_in, + "token_type": "Bearer", + } + return response + + +class TestGatewayAuthTokenProvider: + class TestConfigured: + def test__given_all_kwargs__then_configured_true(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + assert provider.configured is True + + def test__given_missing_client_secret__then_configured_false(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret="", + ) + + assert provider.configured is False + + def test__given_empty_env__then_configured_false(self): + with patch.dict( + "os.environ", + { + "GATEWAY_AUTH_ISSUER": "", + "GATEWAY_AUTH_AUDIENCE": "", + "GATEWAY_AUTH_CLIENT_ID": "", + "GATEWAY_AUTH_CLIENT_SECRET": "", + }, + clear=False, + ): + provider = GatewayAuthTokenProvider() + + assert provider.configured is False + + class TestGetToken: + def test__given_unconfigured_provider__then_raises(self): + provider = GatewayAuthTokenProvider( + issuer="", audience="", client_id="", client_secret="" + ) + + with pytest.raises(GatewayAuthError): + provider.get_token() + + def test__given_first_call__then_fetches_and_returns_token(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + return_value=_make_token_response("tok-1"), + ) as mock_post: + token = provider.get_token() + + assert token == "tok-1" + mock_post.assert_called_once() + call_kwargs = mock_post.call_args.kwargs + assert call_kwargs["json"]["grant_type"] == "client_credentials" + assert call_kwargs["json"]["client_id"] == CLIENT_ID + assert call_kwargs["json"]["client_secret"] == CLIENT_SECRET + assert call_kwargs["json"]["audience"] == AUDIENCE + assert call_kwargs["timeout"] == 10.0 + positional_url = mock_post.call_args.args[0] + assert positional_url == f"{ISSUER}/oauth/token" + + def test__given_trailing_slash_issuer__then_no_double_slash(self): + provider = GatewayAuthTokenProvider( + issuer=f"{ISSUER}/", + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + return_value=_make_token_response("tok-1"), + ) as mock_post: + provider.get_token() + + url = mock_post.call_args.args[0] + assert url == f"{ISSUER}/oauth/token" + + def test__given_fresh_cached_token__then_second_call_reuses(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + return_value=_make_token_response("tok-1", expires_in=86400), + ) as mock_post: + first = provider.get_token() + second = provider.get_token() + + assert first == second == "tok-1" + mock_post.assert_called_once() + + def test__given_expired_token__then_second_call_refreshes(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + responses = [ + _make_token_response("tok-1", expires_in=60), + _make_token_response("tok-2", expires_in=86400), + ] + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + side_effect=responses, + ) as mock_post: + first = provider.get_token() + # Simulate wall-clock advancing past the refresh margin. + provider._expires_at = time.time() - 1 + second = provider.get_token() + + assert first == "tok-1" + assert second == "tok-2" + assert mock_post.call_count == 2 + + def test__given_auth0_http_error__then_raises_gateway_auth_error(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "unauthorized", + request=MagicMock(), + response=mock_response, + ) + + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + return_value=mock_response, + ): + with pytest.raises(GatewayAuthError): + provider.get_token() + + def test__given_response_without_access_token__then_raises(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + response = MagicMock() + response.status_code = 200 + response.raise_for_status = MagicMock() + response.json.return_value = {"token_type": "Bearer"} + + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + return_value=response, + ): + with pytest.raises(GatewayAuthError): + provider.get_token() + + class TestInvalidate: + def test__given_invalidated_token__then_next_call_refetches(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + responses = [ + _make_token_response("tok-1"), + _make_token_response("tok-2"), + ] + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + side_effect=responses, + ) as mock_post: + provider.get_token() + provider.invalidate() + provider.get_token() + + assert mock_post.call_count == 2 + + class TestFailureModes: + def test__given_network_error__then_raises_gateway_auth_error(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + side_effect=httpx.ConnectError("boom"), + ): + with pytest.raises(GatewayAuthError): + provider.get_token() + + def test__given_missing_expires_in__then_raises(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + response = MagicMock() + response.status_code = 200 + response.raise_for_status = MagicMock() + response.json.return_value = {"access_token": "tok"} + + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + return_value=response, + ): + with pytest.raises(GatewayAuthError): + provider.get_token() + + def test__given_zero_expires_in__then_clamped_to_refresh_margin(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + return_value=_make_token_response("tok", expires_in=0), + ): + provider.get_token() + + ttl = provider._expires_at - time.time() + assert ttl > provider._REFRESH_MARGIN_SECONDS + + class TestThreadSafety: + def test__given_concurrent_callers__then_fetches_once(self): + provider = GatewayAuthTokenProvider( + issuer=ISSUER, + audience=AUDIENCE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + ) + + def slow_response(*_args, **_kwargs): + time.sleep(0.05) + return _make_token_response("tok-concurrent") + + with patch( + "policyengine_api.libs.gateway_auth.httpx.post", + side_effect=slow_response, + ) as mock_post: + tokens: list[str] = [] + threads = [ + threading.Thread(target=lambda: tokens.append(provider.get_token())) + for _ in range(20) + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert mock_post.call_count == 1 + assert tokens == ["tok-concurrent"] * 20 + + +class TestRequireAllOrNoneGatewayAuthEnv: + def test__given_no_env__then_ok(self, monkeypatch): + for name in GATEWAY_AUTH_ENV_VARS: + monkeypatch.delenv(name, raising=False) + + _require_all_or_none_gateway_auth_env() + + def test__given_all_env__then_ok(self, monkeypatch): + for name in GATEWAY_AUTH_ENV_VARS: + monkeypatch.setenv(name, "x") + + _require_all_or_none_gateway_auth_env() + + def test__given_partial_env__then_raises(self, monkeypatch): + monkeypatch.setenv("GATEWAY_AUTH_ISSUER", "https://tenant.auth0.com") + monkeypatch.setenv("GATEWAY_AUTH_AUDIENCE", "aud") + monkeypatch.delenv("GATEWAY_AUTH_CLIENT_ID", raising=False) + monkeypatch.delenv("GATEWAY_AUTH_CLIENT_SECRET", raising=False) + + with pytest.raises(GatewayAuthError): + _require_all_or_none_gateway_auth_env() + + +class TestGatewayBearerAuthRetry: + def test__given_401_response__then_invalidates_and_retries_with_fresh_token(self): + provider = MagicMock() + provider.get_token.side_effect = ["stale-token", "fresh-token"] + auth = GatewayBearerAuth(provider) + + request = httpx.Request("GET", "https://example.invalid/jobs/abc") + flow = auth.auth_flow(request) + + first_request = next(flow) + assert first_request.headers["Authorization"] == "Bearer stale-token" + + unauthorized = httpx.Response(401, request=first_request) + retry_request = flow.send(unauthorized) + + assert retry_request.headers["Authorization"] == "Bearer fresh-token" + provider.invalidate.assert_called_once() + with pytest.raises(StopIteration): + flow.send(httpx.Response(200, request=retry_request)) + + def test__given_2xx_response__then_no_retry(self): + provider = MagicMock() + provider.get_token.return_value = "tok" + auth = GatewayBearerAuth(provider) + + request = httpx.Request("GET", "https://example.invalid/jobs/abc") + flow = auth.auth_flow(request) + + next(flow) + with pytest.raises(StopIteration): + flow.send(httpx.Response(200, request=request)) + + provider.invalidate.assert_not_called() + + +class TestGatewayBearerAuth: + def test__given_request__then_attaches_bearer_token_header(self): + provider = MagicMock() + provider.get_token.return_value = "tok-xyz" + auth = GatewayBearerAuth(provider) + + request = httpx.Request("GET", "https://example.invalid/") + flow = auth.auth_flow(request) + next(flow) + + assert request.headers["Authorization"] == "Bearer tok-xyz" + provider.get_token.assert_called_once() diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index 10b278c82..0ee3e0248 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -120,6 +120,53 @@ def test__given_env_var_not_set__then_uses_default_url(self, mock_httpx_client): assert "policyengine-simulation-gateway" in api.base_url assert "modal.run" in api.base_url + def test__given_gateway_auth_env_vars__then_attaches_bearer_auth( + self, mock_httpx_client, monkeypatch + ): + from policyengine_api.libs.simulation_api_modal import httpx as modal_httpx + from policyengine_api.libs.gateway_auth import GatewayBearerAuth + + monkeypatch.setenv("GATEWAY_AUTH_ISSUER", "https://tenant.auth0.com") + monkeypatch.setenv("GATEWAY_AUTH_AUDIENCE", "https://sim-gateway") + monkeypatch.setenv("GATEWAY_AUTH_CLIENT_ID", "id") + monkeypatch.setenv("GATEWAY_AUTH_CLIENT_SECRET", "secret") + + SimulationAPIModal() + + _, kwargs = modal_httpx.Client.call_args + assert isinstance(kwargs.get("auth"), GatewayBearerAuth) + + def test__given_missing_gateway_auth_env_vars__then_no_auth_attached( + self, mock_httpx_client, monkeypatch + ): + from policyengine_api.libs.simulation_api_modal import httpx as modal_httpx + + for key in ( + "GATEWAY_AUTH_ISSUER", + "GATEWAY_AUTH_AUDIENCE", + "GATEWAY_AUTH_CLIENT_ID", + "GATEWAY_AUTH_CLIENT_SECRET", + ): + monkeypatch.delenv(key, raising=False) + + SimulationAPIModal() + + _, kwargs = modal_httpx.Client.call_args + assert kwargs.get("auth") is None + + def test__given_partial_gateway_auth_env_vars__then_raises( + self, mock_httpx_client, monkeypatch + ): + from policyengine_api.libs.gateway_auth import GatewayAuthError + + monkeypatch.setenv("GATEWAY_AUTH_ISSUER", "https://tenant.auth0.com") + monkeypatch.setenv("GATEWAY_AUTH_AUDIENCE", "aud") + monkeypatch.delenv("GATEWAY_AUTH_CLIENT_ID", raising=False) + monkeypatch.delenv("GATEWAY_AUTH_CLIENT_SECRET", raising=False) + + with pytest.raises(GatewayAuthError): + SimulationAPIModal() + class TestRun: def test__given_valid_payload__then_returns_execution_with_job_id( self,