diff --git a/README.md b/README.md index c2412665..2c4ece4a 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,27 @@ The reference implementation is meant to be customized for your facility's IRI i ### Customizing the business logic for your facility The IRI API handles the "boilerplate" of setting up the rest API. It delegates to the per-facility business logic via interface definitions. These interfaces are implemented as abstract classes, one per api group (status, account, etc.). Each router directory defines a FacilityAdapter class (eg. [the status adapter](app/routers/status/facility_adapter.py)) that is expected to be implemented by the facility who is exposing an IRI API instance. +## Forwarded Project Header For Compute Requests + +Compute submission and update requests support a trusted forwarded header named `X-IRI-Facility-Project`. + +This header is intended for deployments where an upstream trusted component has already resolved the caller's project/account into the facility-native value required by the downstream scheduler or execution system. + +When `X-IRI-Facility-Project` is present and valid: + +- IRI treats that header value as the effective project/account for the compute request. +- The downstream compute adapter receives the request as if that value were the facility-native account to use for job submission or update. +- Implementations may surface that effective value in returned job metadata, scheduler requests, labels, annotations, or similar downstream submission context. + +For compute submit/update requests, the effective project/account must be specified in exactly one place: + +- `job_spec.attributes.account`, or +- `X-IRI-Facility-Project` + +If both are provided, IRI returns `400 Bad Request`. +If neither is provided, IRI returns `400 Bad Request`. +This behavior is specific to compute submission/update handling; read-only endpoints are unchanged. + The specific implementations can be specified via the `IRI_API_ADAPTER_*` environment variables. For example the adapter for the `status` api would be given by setting `IRI_API_ADAPTER_status` to the full python module and class implementing `app.routers.status.facility_adapter.FacilityAdapter`. (eg. `IRI_API_ADAPTER_status=myfacility.MyFacilityStatusAdapter`) As a default implementation, this project supplies the [demo adapter](app/demo_adapter.py) which implements every facility adapter with fake data. diff --git a/app/demo_adapter.py b/app/demo_adapter.py index fa9a23ae..11a015b8 100644 --- a/app/demo_adapter.py +++ b/app/demo_adapter.py @@ -30,6 +30,7 @@ from .routers.status import models as status_models from .routers.task import facility_adapter as task_adapter from .routers.task import models as task_models +from .request_context import get_iri_facility_project from .types.models import Capability from .types.user import User from .types.scalars import AllocationUnit @@ -542,6 +543,8 @@ async def submit_job( user: User, job_spec: compute_models.JobSpec, ) -> compute_models.Job: + facility_project = get_iri_facility_project() + account = facility_project or (job_spec.attributes.account if job_spec.attributes else None) return compute_models.Job( id="job_123", status=compute_models.JobStatus( @@ -549,7 +552,7 @@ async def submit_job( time=utc_timestamp(), message="job submitted", exit_code=0, - meta_data={"account": "account1"}, + meta_data={"account": account}, ), ) @@ -560,6 +563,8 @@ async def update_job( job_spec: compute_models.JobSpec, job_id: str, ) -> compute_models.Job: + facility_project = get_iri_facility_project() + account = facility_project or (job_spec.attributes.account if job_spec.attributes else None) return compute_models.Job( id=job_id, status=compute_models.JobStatus( @@ -567,7 +572,7 @@ async def update_job( time=utc_timestamp(), message="job updated", exit_code=0, - meta_data={"account": "account1"}, + meta_data={"account": account}, ), ) diff --git a/app/main.py b/app/main.py index 1f6faccc..30480024 100644 --- a/app/main.py +++ b/app/main.py @@ -17,7 +17,7 @@ from . import config from .apilogger import configure_logging -from .request_context import set_api_url_base, _api_url_base +from .request_context import _api_url_base, _iri_facility_project, set_api_url_base from app.routers.error_handlers import install_error_handlers from app.routers.facility import facility @@ -58,12 +58,14 @@ class _ExternalRequestContextMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - token = _api_url_base.set(None) + url_token = _api_url_base.set(None) + facility_project_token = _iri_facility_project.set(None) try: set_api_url_base(request) return await call_next(request) finally: - _api_url_base.reset(token) + _api_url_base.reset(url_token) + _iri_facility_project.reset(facility_project_token) APP.add_middleware(_ExternalRequestContextMiddleware) diff --git a/app/request_context.py b/app/request_context.py index cc8c8828..8f95f072 100644 --- a/app/request_context.py +++ b/app/request_context.py @@ -6,6 +6,7 @@ from . import config _api_url_base: ContextVar[str | None] = ContextVar("_api_url_base", default=None) +_iri_facility_project: ContextVar[str | None] = ContextVar("_iri_facility_project", default=None) def _first_header_value(value: str | None) -> str: @@ -22,6 +23,8 @@ def set_api_url_base(request: Request) -> None: api_url = config.API_URL.strip("/") if host: _api_url_base.set(f"{proto}://{host}{prefix}{api_prefix}/{api_url}") + facility_project = _first_header_value(request.headers.get("x-iri-facility-project")) + _iri_facility_project.set(facility_project or None) def get_url_prefix() -> str: @@ -30,3 +33,8 @@ def get_url_prefix() -> str: if value: return value return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}" + + +def get_iri_facility_project() -> str | None: + """Return the facility-native project/account identifier forwarded by RIG.""" + return _iri_facility_project.get() diff --git a/app/routers/compute/compute.py b/app/routers/compute/compute.py index 71c80b45..08507d6a 100644 --- a/app/routers/compute/compute.py +++ b/app/routers/compute/compute.py @@ -1,6 +1,6 @@ """Compute resource API router""" -from fastapi import Depends, HTTPException, Query, Request, status +from fastapi import Depends, Query, Request, status from ...types.http import forbidExtraQueryParams from ...types.scalars import StrictHTTPBool @@ -16,8 +16,6 @@ prefix="/compute", tags=["compute"], ) - - @router.post( "/job/{resource_id:str}", response_model=models.Job, @@ -31,6 +29,7 @@ async def submit_job( job_spec: models.JobSpec, request: Request, user: User = Depends(router.current_user), + project_name: str | None = Depends(router.iri_header_project), _forbid=Depends(forbidExtraQueryParams()), ): """ @@ -38,6 +37,12 @@ async def submit_job( - **resource**: the name of the compute resource to use - **job_request**: a PSIJ job spec as defined here + - **project/account resolution**: + The effective project/account for the submission must be supplied in exactly one place: + `job_spec.attributes.account` or the trusted `X-IRI-Facility-Project` request header. + If the forwarded header is present and valid, IRI treats its value as the effective facility-native project/account + for the downstream submission and related job metadata. If both sources are present, or neither is present, + the request is rejected with `400 Bad Request`. This command will attempt to submit a job and return its id. """ @@ -63,6 +68,7 @@ async def update_job( job_spec: models.JobSpec, request: Request, user: User = Depends(router.current_user), + project_name: str | None = Depends(router.iri_header_project), _forbid=Depends(forbidExtraQueryParams()), ): """ @@ -71,6 +77,12 @@ async def update_job( - **resource**: the name of the compute resource to use - **job_request**: a PSIJ job spec as defined here + - **project/account resolution**: + The effective project/account for the update must be supplied in exactly one place: + `job_spec.attributes.account` or the trusted `X-IRI-Facility-Project` request header. + If the forwarded header is present and valid, IRI treats its value as the effective facility-native project/account + for downstream update handling and job metadata. If both sources are present, or neither is present, + the request is rejected with `400 Bad Request`. """ # look up the resource (todo: maybe ensure it's available) diff --git a/app/routers/compute/models.py b/app/routers/compute/models.py index cea26492..7c803d17 100644 --- a/app/routers/compute/models.py +++ b/app/routers/compute/models.py @@ -28,7 +28,16 @@ class JobAttributes(IRIBaseModel): duration: int|None = Field(default=None, description="Duration in seconds", ge=1, examples=[30, 60, 120]) queue_name: str|None = Field(default=None, min_length=1, description="Name of the queue or partition to submit the job to", example="debug") - account: str|None = Field(default=None, min_length=1, description="Account or project to charge for resource usage", example="proj123") + account: str|None = Field( + default=None, + min_length=1, + description=( + "Account or project to charge for resource usage. " + "For compute submission/update requests, specify this here only when the caller is not relying on a trusted forwarded " + "`X-IRI-Facility-Project` header. If that header is present and valid, this field must be omitted." + ), + example="proj123", + ) reservation_id: str|None = Field(default=None, min_length=1, description="ID of a reservation to use for the job", example="resv-42") custom_attributes: dict[str, str] = Field(default_factory=dict, description="Custom scheduler-specific attributes as key-value pairs", example={"constraint": "gpu"}) @@ -79,7 +88,14 @@ class JobSpec(IRIBaseModel): stdout_path: str|None = Field(default=None, min_length=1, description="Path to file to write standard output", example="/home/user/output.txt") stderr_path: str|None = Field(default=None, min_length=1, description="Path to file to write standard error", example="/home/user/error.txt") resources: ResourceSpec|None = Field(default=None, description="Resource requirements for the job") - attributes: JobAttributes|None = Field(default=None, description="Additional job attributes such as duration, queue, and account") + attributes: JobAttributes|None = Field( + default=None, + description=( + "Additional job attributes such as duration, queue, and account. " + "For compute submission/update, the effective project/account must be supplied in exactly one place: " + "`attributes.account` or the trusted `X-IRI-Facility-Project` request header." + ), + ) pre_launch: str|None = Field(default=None, min_length=1, description="Script or commands to run before launching the job", example="module load cuda") post_launch: str|None = Field(default=None, min_length=1, description="Script or commands to run after the job completes", example="echo done") launcher: str|None = Field(default=None, min_length=1, description="Job launcher to use (e.g., 'mpirun', 'srun')", example="srun") diff --git a/app/routers/iri_router.py b/app/routers/iri_router.py index 8abcc4b5..0542193f 100644 --- a/app/routers/iri_router.py +++ b/app/routers/iri_router.py @@ -3,10 +3,12 @@ import logging import importlib import time +from typing import Any import globus_sdk -from fastapi import Request, Depends, HTTPException, APIRouter +from fastapi import Body, Request, Depends, HTTPException, APIRouter from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from ..request_context import get_iri_facility_project from ..types.user import User bearer_scheme = HTTPBearer() @@ -159,6 +161,29 @@ async def current_user( raise HTTPException(status_code=404, detail="User not found") return user + async def iri_header_project(self, request: Request, job_spec: dict[str, Any] | None = Body(default=None)) -> str | None: + """Expose and validate the forwarded facility-project header for compute routes.""" + project_name = get_iri_facility_project() + spec_account = None + if job_spec is not None: + attributes = job_spec.get("attributes") + if isinstance(attributes, dict): + spec_account = attributes.get("account") + elif attributes is not None: + # Leave malformed body handling to FastAPI/Pydantic validation. + return project_name + if spec_account and project_name: + raise HTTPException( + status_code=400, + detail="Specify project/account in exactly one place: job_spec.attributes.account or X-IRI-Facility-Project, not both.", + ) + if not spec_account and not project_name: + raise HTTPException( + status_code=400, + detail="Project/account must be specified in exactly one place: job_spec.attributes.account or X-IRI-Facility-Project.", + ) + return project_name + class AuthenticatedAdapter(ABC): @abstractmethod diff --git a/test/test_facility_project_header.py b/test/test_facility_project_header.py new file mode 100644 index 00000000..d2a79d4a --- /dev/null +++ b/test/test_facility_project_header.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +"""Regression tests for facility-project header propagation into compute submission.""" + +import os +import unittest + +from fastapi.testclient import TestClient + +os.environ.setdefault("IRI_SHOW_MISSING_ROUTES", "true") + +from app.main import APP + + +class FacilityProjectHeaderTests(unittest.TestCase): + def test_compute_submit_uses_forwarded_facility_project_header(self): + client = TestClient(APP) + + resources_response = client.get("/api/v1/status/resources") + self.assertEqual(resources_response.status_code, 200) + resource_id = resources_response.json()[0]["id"] + + response = client.post( + f"/api/v1/compute/job/{resource_id}", + headers={ + "authorization": "Bearer 12345", + "x-iri-facility-project": "ns011", + }, + json={"executable": "/bin/echo", "arguments": ["hello"]}, + ) + + self.assertEqual(response.status_code, 200) + body = response.json() + self.assertEqual(body["status"]["meta_data"]["account"], "ns011") + + def test_compute_submit_uses_job_spec_account_when_header_absent(self): + client = TestClient(APP) + + resources_response = client.get("/api/v1/status/resources") + self.assertEqual(resources_response.status_code, 200) + resource_id = resources_response.json()[0]["id"] + + response = client.post( + f"/api/v1/compute/job/{resource_id}", + headers={"authorization": "Bearer 12345"}, + json={ + "executable": "/bin/echo", + "arguments": ["hello"], + "attributes": {"account": "ns011"}, + }, + ) + + self.assertEqual(response.status_code, 200) + body = response.json() + self.assertEqual(body["status"]["meta_data"]["account"], "ns011") + + def test_compute_submit_rejects_missing_project_account(self): + client = TestClient(APP) + + resources_response = client.get("/api/v1/status/resources") + self.assertEqual(resources_response.status_code, 200) + resource_id = resources_response.json()[0]["id"] + + response = client.post( + f"/api/v1/compute/job/{resource_id}", + headers={"authorization": "Bearer 12345"}, + json={"executable": "/bin/echo", "arguments": ["hello"]}, + ) + + self.assertEqual(response.status_code, 400) + self.assertIn("exactly one place", response.json()["detail"]) + + def test_compute_submit_rejects_duplicate_project_account_sources(self): + client = TestClient(APP) + + resources_response = client.get("/api/v1/status/resources") + self.assertEqual(resources_response.status_code, 200) + resource_id = resources_response.json()[0]["id"] + + response = client.post( + f"/api/v1/compute/job/{resource_id}", + headers={ + "authorization": "Bearer 12345", + "x-iri-facility-project": "ns011", + }, + json={ + "executable": "/bin/echo", + "arguments": ["hello"], + "attributes": {"account": "also-present"}, + }, + ) + + self.assertEqual(response.status_code, 400) + self.assertIn("not both", response.json()["detail"]) + + def test_compute_submit_requires_authorization_before_project_validation(self): + client = TestClient(APP) + + response = client.post( + "/api/v1/compute/job/0", + json={"executable": "/bin/echo", "arguments": ["hello"]}, + ) + + self.assertEqual(response.status_code, 401) + + def test_compute_update_requires_authorization_before_project_validation(self): + client = TestClient(APP) + + response = client.put( + "/api/v1/compute/job/0/0", + json={"executable": "/bin/echo", "arguments": ["hello"]}, + ) + + self.assertEqual(response.status_code, 401) + + def test_compute_submit_malformed_attributes_does_not_500(self): + client = TestClient(APP) + resources_response = client.get("/api/v1/status/resources") + self.assertEqual(resources_response.status_code, 200) + resource_id = resources_response.json()[0]["id"] + + response = client.post( + f"/api/v1/compute/job/{resource_id}", + headers={"authorization": "Bearer 12345"}, + json={"executable": "/bin/echo", "arguments": ["hello"], "attributes": [None, None]}, + ) + + self.assertIn(response.status_code, {400, 422}) + + def test_compute_update_malformed_attributes_does_not_500(self): + client = TestClient(APP) + resources_response = client.get("/api/v1/status/resources") + self.assertEqual(resources_response.status_code, 200) + resource_id = resources_response.json()[0]["id"] + + response = client.put( + f"/api/v1/compute/job/{resource_id}/0", + headers={"authorization": "Bearer 12345"}, + json={"executable": "/bin/echo", "arguments": ["hello"], "attributes": [None, None]}, + ) + + self.assertIn(response.status_code, {400, 422}) + + +if __name__ == "__main__": + unittest.main()