Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ jobs:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
GATEWAY_AUTH_ISSUER: ${{ secrets.GATEWAY_AUTH_ISSUER }}
GATEWAY_AUTH_AUDIENCE: ${{ secrets.GATEWAY_AUTH_AUDIENCE }}
GATEWAY_AUTH_CLIENT_ID: ${{ secrets.GATEWAY_AUTH_CLIENT_ID }}
GATEWAY_AUTH_CLIENT_SECRET: ${{ secrets.GATEWAY_AUTH_CLIENT_SECRET }}
docker:
name: Docker
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions changelog.d/add-gateway-auth-client.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Outbound auth to the simulation API gateway. `SimulationAPIModal` now attaches an Auth0 `client_credentials` bearer token to every request to the Modal simulation gateway, via a new `GatewayAuthTokenProvider` that caches and refreshes tokens in-process. Configured by four env vars: `GATEWAY_AUTH_ISSUER`, `GATEWAY_AUTH_AUDIENCE`, `GATEWAY_AUTH_CLIENT_ID`, `GATEWAY_AUTH_CLIENT_SECRET`. If all four are unset, no auth is attached (preserves local/dev behavior). If only some of them are set, `SimulationAPIModal` raises `GatewayAuthError` at construction instead of silently sending unauthenticated requests.
12 changes: 12 additions & 0 deletions gcp/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
ANTHROPIC_API_KEY = os.environ["ANTHROPIC_API_KEY"]
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
HUGGING_FACE_TOKEN = os.environ["HUGGING_FACE_TOKEN"]
GATEWAY_AUTH_ISSUER = os.environ["GATEWAY_AUTH_ISSUER"]
GATEWAY_AUTH_AUDIENCE = os.environ["GATEWAY_AUTH_AUDIENCE"]
GATEWAY_AUTH_CLIENT_ID = os.environ["GATEWAY_AUTH_CLIENT_ID"]
GATEWAY_AUTH_CLIENT_SECRET = os.environ["GATEWAY_AUTH_CLIENT_SECRET"]

# Export GAE to .gac.json and DB_PD to .dbpw in the current directory

Expand All @@ -35,6 +39,14 @@
dockerfile = dockerfile.replace(".anthropic_api_key", ANTHROPIC_API_KEY)
dockerfile = dockerfile.replace(".openai_api_key", OPENAI_API_KEY)
dockerfile = dockerfile.replace(".hugging_face_token", HUGGING_FACE_TOKEN)
dockerfile = dockerfile.replace(".gateway_auth_issuer", GATEWAY_AUTH_ISSUER)
dockerfile = dockerfile.replace(".gateway_auth_audience", GATEWAY_AUTH_AUDIENCE)
dockerfile = dockerfile.replace(
".gateway_auth_client_id", GATEWAY_AUTH_CLIENT_ID
)
dockerfile = dockerfile.replace(
".gateway_auth_client_secret", GATEWAY_AUTH_CLIENT_SECRET
)

with open(dockerfile_location, "w") as f:
f.write(dockerfile)
4 changes: 4 additions & 0 deletions gcp/policyengine_api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ ENV ANTHROPIC_API_KEY .anthropic_api_key
ENV OPENAI_API_KEY .openai_api_key
ENV HUGGING_FACE_TOKEN .hugging_face_token
ENV CREDENTIALS_JSON_API_V2 .credentials_json_api_v2
ENV GATEWAY_AUTH_ISSUER .gateway_auth_issuer
ENV GATEWAY_AUTH_AUDIENCE .gateway_auth_audience
ENV GATEWAY_AUTH_CLIENT_ID .gateway_auth_client_id
ENV GATEWAY_AUTH_CLIENT_SECRET .gateway_auth_client_secret

WORKDIR /app

Expand Down
191 changes: 191 additions & 0 deletions policyengine_api/libs/gateway_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Auth0 client_credentials support for outbound calls to the simulation gateway.

The simulation API gateway (``policyengine-api-v2``) gates its write and
job-status endpoints behind a bearer JWT minted by the PolicyEngine Auth0
tenant. This module fetches that token for the v1 API process, caches it in
memory, and attaches it to every outbound HTTP call via an ``httpx.Auth``
implementation.
"""

from __future__ import annotations

import os
import threading
import time
from typing import Optional

import httpx


# Names of the environment variables that configure outbound gateway auth.
# They are kept as constants so every reader of the configuration uses the
# exact same keys.
GATEWAY_AUTH_ISSUER_ENV = "GATEWAY_AUTH_ISSUER"
GATEWAY_AUTH_AUDIENCE_ENV = "GATEWAY_AUTH_AUDIENCE"
GATEWAY_AUTH_CLIENT_ID_ENV = "GATEWAY_AUTH_CLIENT_ID"
GATEWAY_AUTH_CLIENT_SECRET_ENV = "GATEWAY_AUTH_CLIENT_SECRET"

# The full set of variables that must be provided together. The order here is
# the order in which names appear in configuration error messages.
GATEWAY_AUTH_ENV_VARS = (
    GATEWAY_AUTH_ISSUER_ENV,
    GATEWAY_AUTH_AUDIENCE_ENV,
    GATEWAY_AUTH_CLIENT_ID_ENV,
    GATEWAY_AUTH_CLIENT_SECRET_ENV,
)


class GatewayAuthError(RuntimeError):
    """Signal that gateway auth cannot be used.

    Raised by this module when the ``GATEWAY_AUTH_*`` configuration is absent
    or only partially set, and when fetching a token from Auth0 fails.
    """


def _require_all_or_none_gateway_auth_env() -> None:
    """Fail fast when only some of the GATEWAY_AUTH_* env vars are set.

    Either all four variables are set (auth enabled) or none are (auth
    disabled). A partial set most likely means a misspelled secret name, and
    letting it through would quietly fall back to unauthenticated gateway
    calls, so it is treated as a hard error.

    A variable counts as set only when its value is a non-empty string.

    Raises:
        GatewayAuthError: if at least one, but not all, of the variables in
            ``GATEWAY_AUTH_ENV_VARS`` is set.
    """
    present = []
    missing = []
    for name in GATEWAY_AUTH_ENV_VARS:
        bucket = present if os.environ.get(name) else missing
        bucket.append(name)

    # Fully configured and fully unconfigured are both acceptable.
    if not present or not missing:
        return

    raise GatewayAuthError(
        "Gateway auth is partially configured: "
        f"{', '.join(present)} set but {', '.join(missing)} missing. "
        "Set all four or none."
    )


class GatewayAuthTokenProvider:
    """Fetch and cache an Auth0 ``client_credentials`` access token.

    The provider guards its cache with a lock and refreshes the token a short
    window before its recorded expiry, so a token handed to a caller is not
    already on the edge of expiring. A single instance can (and should) be
    shared across many HTTP clients.

    Every failure to obtain a usable token, whether from missing
    configuration, a network error, a non-2xx response or a malformed
    response body, is reported as ``GatewayAuthError``.
    """

    # Refresh this many seconds before the cached expiry time.
    _REFRESH_MARGIN_SECONDS = 60

    def __init__(
        self,
        issuer: Optional[str] = None,
        audience: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        *,
        http_timeout: float = 10.0,
    ):
        """Capture the Auth0 settings, falling back to the environment.

        Each positional argument that is left as ``None`` is read from its
        ``GATEWAY_AUTH_*`` environment variable (empty string when unset).
        No network call is made here; the first token is fetched lazily by
        ``get_token``.

        Args:
            issuer: Base URL of the Auth0 tenant; trailing slashes are
                stripped.
            audience: API audience to request the token for.
            client_id: Auth0 application client id.
            client_secret: Auth0 application client secret.
            http_timeout: Timeout, in seconds, passed to the token request.
        """
        self._issuer = (
            issuer
            if issuer is not None
            else os.environ.get(GATEWAY_AUTH_ISSUER_ENV, "")
        ).rstrip("/")
        self._audience = (
            audience
            if audience is not None
            else os.environ.get(GATEWAY_AUTH_AUDIENCE_ENV, "")
        )
        self._client_id = (
            client_id
            if client_id is not None
            else os.environ.get(GATEWAY_AUTH_CLIENT_ID_ENV, "")
        )
        self._client_secret = (
            client_secret
            if client_secret is not None
            else os.environ.get(GATEWAY_AUTH_CLIENT_SECRET_ENV, "")
        )
        self._http_timeout = http_timeout
        self._token: Optional[str] = None
        self._expires_at: float = 0.0
        self._lock = threading.Lock()

    @property
    def configured(self) -> bool:
        """True iff all four required env vars / kwargs were provided."""
        return all(
            (self._issuer, self._audience, self._client_id, self._client_secret)
        )

    def get_token(self) -> str:
        """Return a valid bearer token, fetching or refreshing as needed.

        Raises:
            GatewayAuthError: if the provider is not fully configured or the
                token cannot be fetched.
        """
        if not self.configured:
            raise GatewayAuthError(
                "Gateway auth not configured: set "
                f"{GATEWAY_AUTH_ISSUER_ENV}, {GATEWAY_AUTH_AUDIENCE_ENV}, "
                f"{GATEWAY_AUTH_CLIENT_ID_ENV}, and "
                f"{GATEWAY_AUTH_CLIENT_SECRET_ENV}."
            )
        with self._lock:
            now = time.time()
            if (
                self._token is None
                or now >= self._expires_at - self._REFRESH_MARGIN_SECONDS
            ):
                self._fetch_locked()
            return self._token  # type: ignore[return-value]

    def _fetch_locked(self) -> None:
        """Call Auth0's ``/oauth/token``. Caller must hold ``_lock``.

        On success the cached token and expiry are replaced. On any failure
        the previous cache contents are left untouched.

        Raises:
            GatewayAuthError: on a network error, a non-2xx status, a body
                that is not valid JSON, or a payload without a usable
                ``access_token`` / ``expires_in``.
        """
        try:
            response = httpx.post(
                f"{self._issuer}/oauth/token",
                json={
                    "client_id": self._client_id,
                    "client_secret": self._client_secret,
                    "audience": self._audience,
                    "grant_type": "client_credentials",
                },
                timeout=self._http_timeout,
            )
        except httpx.RequestError as exc:
            raise GatewayAuthError(
                f"Auth0 token fetch network error: {exc}"
            ) from exc

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise GatewayAuthError(
                f"Auth0 token fetch failed: HTTP {response.status_code}"
            ) from exc

        # A 2xx response is not guaranteed to carry JSON (e.g. an
        # intermediary answering with an HTML page). Decoding errors are
        # ValueError subclasses; surface them as the module's own error type
        # so callers only need to handle GatewayAuthError.
        try:
            payload = response.json()
        except ValueError as exc:
            raise GatewayAuthError(
                "Auth0 token response was not valid JSON"
            ) from exc

        token, expires_in = self._parse_token_payload(payload)
        self._token = token
        self._expires_at = time.time() + expires_in

    @classmethod
    def _parse_token_payload(cls, payload: object) -> tuple[str, int]:
        """Validate a decoded token response and extract what is cached.

        Args:
            payload: The decoded JSON body of the token response.

        Returns:
            ``(access_token, expires_in_seconds)``, where the lifetime has
            been clamped to at least twice the refresh margin.

        Raises:
            GatewayAuthError: if the payload is not a JSON object, lacks a
                non-empty string ``access_token``, lacks ``expires_in``, or
                carries an ``expires_in`` that cannot be read as an integer.
        """
        if not isinstance(payload, dict):
            raise GatewayAuthError("Auth0 response was not a JSON object")

        token = payload.get("access_token")
        if not isinstance(token, str) or not token:
            raise GatewayAuthError("Auth0 response missing access_token")

        raw_expires_in = payload.get("expires_in")
        if raw_expires_in is None:
            raise GatewayAuthError("Auth0 response missing expires_in")
        try:
            expires_in = int(raw_expires_in)
        except (TypeError, ValueError, OverflowError) as exc:
            raise GatewayAuthError(
                "Auth0 response has invalid expires_in"
            ) from exc

        # Clamp expires_in so a pathological short/zero value from Auth0
        # cannot drive the refresh check into perpetual refetching under
        # concurrent load.
        return token, max(expires_in, cls._REFRESH_MARGIN_SECONDS * 2)

    def invalidate(self) -> None:
        """Drop the cached token so the next ``get_token`` call refetches.

        Intended for recovery after a 401 from the gateway (e.g. the Auth0
        signing key rotated) rather than routine use.
        """
        with self._lock:
            self._token = None
            self._expires_at = 0.0


class GatewayBearerAuth(httpx.Auth):
    """``httpx.Auth`` adapter that sends a bearer token with every request.

    The token comes from a shared ``GatewayAuthTokenProvider``. If the first
    attempt is answered with a 401, the provider's cached token is dropped
    and the request is sent exactly once more with a newly obtained token.
    This lets a long-lived worker recover when the token it holds has
    stopped being accepted (for example after Auth0 rotates its signing
    keys).
    """

    def __init__(self, token_provider: GatewayAuthTokenProvider):
        self._token_provider = token_provider

    def _attach_bearer(self, request) -> None:
        """Set the ``Authorization`` header from the provider's current token."""
        request.headers["Authorization"] = (
            f"Bearer {self._token_provider.get_token()}"
        )

    def auth_flow(self, request):
        """Yield the authorised request, retrying once after a 401."""
        self._attach_bearer(request)
        first_response = yield request

        if first_response.status_code == 401:
            # The cached token was rejected: discard it and retry once with
            # a freshly fetched one.
            self._token_provider.invalidate()
            self._attach_bearer(request)
            yield request
24 changes: 23 additions & 1 deletion policyengine_api/libs/simulation_api_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
import httpx

from policyengine_api.gcp_logging import logger
from policyengine_api.libs.gateway_auth import (
GatewayAuthTokenProvider,
GatewayBearerAuth,
_require_all_or_none_gateway_auth_env,
)


@dataclass
Expand Down Expand Up @@ -47,7 +52,24 @@ def __init__(self):
"SIMULATION_API_URL",
"https://policyengine--policyengine-simulation-gateway-web-app.modal.run",
)
self.client = httpx.Client(timeout=30.0)
self._token_provider = GatewayAuthTokenProvider()
_require_all_or_none_gateway_auth_env()
auth = (
GatewayBearerAuth(self._token_provider)
if self._token_provider.configured
else None
)
if auth is None:
logger.log_struct(
{
"message": (
"SimulationAPIModal initialised without gateway auth; "
"all GATEWAY_AUTH_* env vars are unset."
),
},
severity="WARNING",
)
self.client = httpx.Client(timeout=30.0, auth=auth)

def run(self, payload: dict) -> ModalSimulationExecution:
"""
Expand Down
Loading
Loading