Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions app/demo_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,14 +493,14 @@ async def get_current_user(
raise HTTPException(status_code=401, detail="Invalid API key")
return "gtorok"

async def get_current_user_globus(
async def get_current_user_oidc(
self: "DemoAdapter",
api_key: str,
client_ip: str,
globus_introspect: dict | None,
token_info: dict | None,
) -> str:
"""
Decode the api_key and return the authenticated user's id from information returned by introspecting a globus token.
Decode the api_key and return the authenticated user's id from information returned by an OIDC token.
This method is not called directly, rather authorized endpoints "depend" on it.
(https://fastapi.tiangolo.com/tutorial/dependencies/)
"""
Expand All @@ -511,7 +511,7 @@ async def get_user(
user_id: str,
api_key: str,
client_ip: str | None,
globus_introspect: dict | None,
token_info: dict | None,
) -> User:
if user_id != self.user.id:
raise HTTPException(status_code=403, detail="User not found")
Expand Down
2 changes: 0 additions & 2 deletions app/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from . import config

_api_url_base: ContextVar[str | None] = ContextVar("_api_url_base", default=None)


def _first_header_value(value: str | None) -> str:
"""Return the first comma-delimited header value with surrounding whitespace removed."""
return (value or "").split(",")[0].strip()
Expand Down
200 changes: 149 additions & 51 deletions app/routers/iri_router.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,129 @@
from abc import ABC, abstractmethod
import asyncio
import json
import os
import logging
import importlib
import threading
import time
import globus_sdk
from typing import Any
import jwt
from jwt import PyJWKClient
from jwt.exceptions import InvalidTokenError
from urllib.error import URLError
from urllib.request import Request as UrlRequest, urlopen
from fastapi import Request, Depends, HTTPException, APIRouter
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from ..types.user import User

bearer_scheme = HTTPBearer()
_DISCOVERY_TIMEOUT_SECONDS = float(os.environ.get("OIDC_DISCOVERY_TIMEOUT_SECONDS", "10"))
_DISCOVERY_CACHE_TTL_SECONDS = float(os.environ.get("OIDC_DISCOVERY_CACHE_TTL_SECONDS", "300"))
_JWKS_CACHE_LIFESPAN_SECONDS = float(os.environ.get("OIDC_JWKS_CACHE_LIFESPAN_SECONDS", "3600"))
_oidc_remote_cache_lock = threading.Lock()
_oidc_remote_cache: dict[str, tuple[float, dict[str, Any], PyJWKClient]] = {}


GLOBUS_RS_ID = os.environ.get("GLOBUS_RS_ID")
GLOBUS_RS_SECRET = os.environ.get("GLOBUS_RS_SECRET")
GLOBUS_RS_SCOPE_SUFFIX = os.environ.get("GLOBUS_RS_SCOPE_SUFFIX")
def _oidc_auth_config() -> dict[str, str] | None:
discovery_uri = os.environ.get("OIDC_DISCOVERY_URI")
client_id = os.environ.get("OIDC_CLIENT_ID")

if not discovery_uri or not client_id:
return None

required_scopes = tuple(
scope
for scope in (
os.environ.get("OIDC_REQUIRED_SCOPES")
or os.environ.get("OIDC_REQUIRED_SCOPE")
or ""
).replace(",", " ").split()
if scope
)

return {
"discovery_uri": discovery_uri,
"client_id": client_id,
"required_scopes": required_scopes,
"required_audience": os.environ.get("OIDC_REQUIRED_AUDIENCE") or client_id,
}


def _fetch_oidc_discovery_document(discovery_uri: str) -> dict[str, Any]:
request = UrlRequest(
discovery_uri,
headers={"Accept": "application/json"},
)
with urlopen(request, timeout=_DISCOVERY_TIMEOUT_SECONDS) as response:
payload = response.read().decode("utf-8")
metadata = json.loads(payload)
jwks_uri = metadata.get("jwks_uri")
if not jwks_uri:
raise RuntimeError("OIDC discovery document is missing jwks_uri")
return metadata


def _load_oidc_remote_state(discovery_uri: str) -> tuple[dict[str, Any], PyJWKClient]:
now = time.time()
cached: tuple[float, dict[str, Any], PyJWKClient] | None = None
with _oidc_remote_cache_lock:
cached = _oidc_remote_cache.get(discovery_uri)
if cached and now - cached[0] < _DISCOVERY_CACHE_TTL_SECONDS:
return cached[1], cached[2]

try:
metadata = _fetch_oidc_discovery_document(discovery_uri)
except Exception:
if cached:
logging.getLogger(__name__).warning(
"OIDC discovery refresh failed for %s; using cached metadata and JWKS client",
discovery_uri,
exc_info=True,
)
return cached[1], cached[2]
raise

with _oidc_remote_cache_lock:
cached = _oidc_remote_cache.get(discovery_uri)
if cached and cached[1].get("jwks_uri") == metadata["jwks_uri"]:
jwks_client = cached[2]
else:
jwks_client = PyJWKClient(
metadata["jwks_uri"],
cache_keys=True,
cache_jwk_set=True,
lifespan=_JWKS_CACHE_LIFESPAN_SECONDS,
timeout=_DISCOVERY_TIMEOUT_SECONDS,
)
_oidc_remote_cache[discovery_uri] = (now, metadata, jwks_client)
return metadata, jwks_client


def _normalize_scope_claim(scope: Any) -> set[str]:
if isinstance(scope, str):
return {item for item in scope.split() if item}
if isinstance(scope, list):
return {str(item) for item in scope if str(item)}
return set()


def _decode_oidc_jwt(
api_key: str,
*,
discovery_uri: str,
required_audience: str,
) -> dict[str, Any]:
metadata, jwks_client = _load_oidc_remote_state(discovery_uri)
signing_key = jwks_client.get_signing_key_from_jwt(api_key)
return jwt.decode(
api_key,
signing_key,
algorithms=None,
audience=required_audience,
issuer=metadata["issuer"],
options={"require": ["exp", "iat", "nbf", "iss"]},
)


def get_client_ip(request: Request) -> str | None:
Expand Down Expand Up @@ -80,45 +189,35 @@ def create_adapter(router_name, router_adapter):
return AdapterClass()


async def get_globus_info(self, api_key: str) -> dict:
"""Returns the linked identities and the session info objects"""
# Introspect the IRI API token using resource server credentials
globus_client = globus_sdk.ConfidentialAppAuthClient(GLOBUS_RS_ID, GLOBUS_RS_SECRET)
# grab identity_set_detail for linked identities and session_info to see how the user logged in
introspect = globus_client.oauth2_token_introspect(api_key, include="identity_set_detail,session_info")
logging.getLogger().info("IRI TOKEN INTROSPECTION:")
logging.getLogger().info(introspect)
if not introspect.get("active"):
raise Exception("Inactive token")

# Check exp (expiration time) claim
exp = introspect.get("exp")
if exp and time.time() >= exp:
raise Exception("Token has expired")
async def get_oidc_token_info(self, api_key: str) -> dict[str, Any]:
"""Validate a bearer JWT against the configured OIDC provider."""
config = _oidc_auth_config()
if not config:
raise RuntimeError("OIDC auth is not configured")

# Check nbf (not before) claim
nbf = introspect.get("nbf")
if nbf and time.time() < nbf:
raise Exception("Token not yet valid")

# Check if token has the required IRI scope
token_scope = introspect.get("scope", "").split()
GLOBUS_SCOPE = f"https://auth.globus.org/scopes/{GLOBUS_RS_ID}/{GLOBUS_RS_SCOPE_SUFFIX}"
if GLOBUS_SCOPE not in token_scope:
raise Exception(f"Token missing required scope: {GLOBUS_SCOPE}")

session_info = introspect.get("session_info")
try:
token_info = await asyncio.to_thread(
_decode_oidc_jwt,
api_key,
discovery_uri=config["discovery_uri"],
required_audience=config["required_audience"],
)
except URLError as exc:
raise RuntimeError(f"OIDC discovery/JWKS request failed: {exc.reason}") from exc
except InvalidTokenError as exc:
raise RuntimeError(f"OIDC JWT validation failed: {exc}") from exc

if not session_info:
raise Exception("No recent login was found in the token (missing session_info). "
"Please re-authenticate to obtain a valid session.")
logging.getLogger().info("PING OIDC JWT VALIDATION CLAIMS:")
logging.getLogger().info(token_info)

authentications = session_info.get("authentications")
if not authentications:
raise Exception("No recent login was found in the token (empty session_info.authentications). "
"Please re-authenticate to obtain a valid session.")
required_scopes = config["required_scopes"]
if required_scopes:
token_scope = _normalize_scope_claim(token_info.get("scope"))
missing_scopes = [scope for scope in required_scopes if scope not in token_scope]
if missing_scopes:
raise Exception(f"Token missing required scopes: {', '.join(missing_scopes)}")

return introspect
return token_info


async def current_user(
Expand All @@ -129,16 +228,16 @@ async def current_user(
token = credentials.credentials
ip_address = get_client_ip(request)
user_id = None
globus_introspect = None
token_info = None
exc_msg = ""
try:
if GLOBUS_RS_ID and GLOBUS_RS_SECRET and GLOBUS_RS_SCOPE_SUFFIX:
if _oidc_auth_config():
try:
globus_introspect = await self.get_globus_info(token)
user_id = await self.adapter.get_current_user_globus(token, ip_address, globus_introspect)
except Exception as globus_exc:
logging.getLogger().exception("Globus error:", exc_info=globus_exc)
exc_msg = f"Globus authentication failed: {str(globus_exc)}. || "
token_info = await self.get_oidc_token_info(token)
user_id = await self.adapter.get_current_user_oidc(token, ip_address, token_info)
except Exception as oidc_exc:
logging.getLogger().exception("OIDC auth error:", exc_info=oidc_exc)
exc_msg = f"OIDC authentication failed: {str(oidc_exc)}. || "
if not user_id:
user_id = await self.adapter.get_current_user(token, ip_address)
except Exception as exc:
Expand All @@ -152,14 +251,13 @@ async def current_user(
user_id=user_id,
api_key=token,
client_ip=ip_address,
globus_introspect=globus_introspect,
token_info=token_info,
)

if not user:
raise HTTPException(status_code=404, detail="User not found")
return user


class AuthenticatedAdapter(ABC):
@abstractmethod
async def get_current_user(self: "AuthenticatedAdapter", api_key: str, client_ip: str | None) -> str:
Expand All @@ -171,16 +269,16 @@ async def get_current_user(self: "AuthenticatedAdapter", api_key: str, client_ip
pass

@abstractmethod
async def get_current_user_globus(self: "AuthenticatedAdapter", api_key: str, client_ip: str | None, globus_introspect: dict | None) -> str:
async def get_current_user_oidc(self: "AuthenticatedAdapter", api_key: str, client_ip: str | None, token_info: dict | None) -> str:
"""
Decode the api_key and return the authenticated user's id from information returned by introspecting a globus token.
Decode the api_key and return the authenticated user's id from information returned by an OIDC token.
This method is not called directly, rather authorized endpoints "depend" on it.
(https://fastapi.tiangolo.com/tutorial/dependencies/)
"""
pass

@abstractmethod
async def get_user(self: "AuthenticatedAdapter", user_id: str, api_key: str, client_ip: str | None, globus_introspect: dict | None) -> User:
async def get_user(self: "AuthenticatedAdapter", user_id: str, api_key: str, client_ip: str | None, token_info: dict | None) -> User:
"""
Retrieve additional user information (name, email, etc.) for the given user_id.
"""
Expand Down
26 changes: 19 additions & 7 deletions local-template.env
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
# globus app credentials
export GLOBUS_APP_ID=<your dev app's id goes here>
export GLOBUS_APP_SECRET=<your dev app's secret goes here>
# optional dev app credentials for local token acquisition tooling
export OIDC_APP_ID=<your dev app's id goes here>
export OIDC_APP_SECRET=<your dev app's secret goes here>

# the resource server's credentials
export GLOBUS_RS_ID=ed3e577d-f7f3-4639-b96e-ff5a8445d699
export GLOBUS_RS_SECRET=<the resource server's secret goes here - ask Gabor to add you to the resource server>
# the client metadata IRI uses for JWT audience validation
export OIDC_CLIENT_ID=<the OIDC client id used by IRI>
export OIDC_DISCOVERY_URI=<the OIDC discovery document URI>

export GLOBUS_RS_SCOPE_SUFFIX=iri_api
# optional: override the JWT audience check (defaults to OIDC_CLIENT_ID)
export OIDC_REQUIRED_AUDIENCE=<optional audience override>

# optional: require specific scopes on accepted access tokens
export OIDC_REQUIRED_SCOPES="openid profile email"

# optional live-test helpers
# OIDC_CLIENT_SECRET is only needed if you want the live test to exchange an auth code.
export OIDC_CLIENT_SECRET=<the OIDC client secret for live auth-code exchange>
export OIDC_TOKEN_ENDPOINT=<the OIDC token endpoint if you want the live test to exchange an auth code>
export OIDC_REDIRECT_URI=urn:ietf:wg:oauth:2.0:oob
export OIDC_AUTHORIZATION_CODE=<paste a one-time auth code here for live testing>
export OIDC_LIVE_ACCESS_TOKEN=<or paste a real access token here to skip the code exchange>
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies = [
"opentelemetry-instrumentation-fastapi>=0.60b1,<0.61b0",
"opentelemetry-exporter-otlp>=1.39.1,<1.40.0",
"globus-sdk>=4.3.1",
"PyJWT>=2.10.1",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you considered a higher level library like authlib? It might take care of some of the validation code in _decode_oidc_jwt. Take a look here: https://docs.authlib.org/en/stable/oauth2/resource-server/flask.html

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not until now. I think that would be a bigger lift for the codebase. Current IRI needs are basically to validate a signed JWT, check the issuer/aud/exp. I think Authlib provides a full set of server/client tooling. Do we need that?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think minimizing boilerplate is usually a good idea. Here are two options:

  1. don't know much about this one: https://fastapi-oidc.readthedocs.io/en/latest/
  2. the "standard": joserfc with:
    1. import keyset (jwks): https://jose.authlib.org/en/recipes/cheatsheet/#key-sets-jwks
    2. decode/validate the jwt: https://jose.authlib.org/en/recipes/cheatsheet/#decode-verify-token (where the 'key' param is the jwks from the prev. step)

"typer>=0.24.1",
]
[tool.ruff]
Expand Down
Loading
Loading