diff --git a/README.md b/README.md
index c2412665..2c4ece4a 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,27 @@ The reference implementation is meant to be customized for your facility's IRI i
### Customizing the business logic for your facility
The IRI API handles the "boilerplate" of setting up the rest API. It delegates to the per-facility business logic via interface definitions. These interfaces are implemented as abstract classes, one per api group (status, account, etc.). Each router directory defines a FacilityAdapter class (eg. [the status adapter](app/routers/status/facility_adapter.py)) that is expected to be implemented by the facility who is exposing an IRI API instance.
+## Forwarded Project Header For Compute Requests
+
+Compute submission and update requests support a trusted forwarded header named `X-IRI-Facility-Project`.
+
+This header is intended for deployments where an upstream trusted component has already resolved the caller's project/account into the facility-native value required by the downstream scheduler or execution system.
+
+When `X-IRI-Facility-Project` is present and valid:
+
+- IRI treats that header value as the effective project/account for the compute request.
+- The downstream compute adapter receives the request as if that value were the facility-native account to use for job submission or update.
+- Implementations may surface that effective value in returned job metadata, scheduler requests, labels, annotations, or similar downstream submission context.
+
+For compute submit/update requests, the effective project/account must be specified in exactly one place:
+
+- `job_spec.attributes.account`, or
+- `X-IRI-Facility-Project`
+
+If both are provided, IRI returns `400 Bad Request`.
+If neither is provided, IRI returns `400 Bad Request`.
+This behavior is specific to compute submission/update handling; read-only endpoints are unchanged.
+
The specific implementations can be specified via the `IRI_API_ADAPTER_*` environment variables. For example the adapter for the `status` api would be given by setting `IRI_API_ADAPTER_status` to the full python module and class implementing `app.routers.status.facility_adapter.FacilityAdapter`. (eg. `IRI_API_ADAPTER_status=myfacility.MyFacilityStatusAdapter`)
As a default implementation, this project supplies the [demo adapter](app/demo_adapter.py) which implements every facility adapter with fake data.
diff --git a/app/demo_adapter.py b/app/demo_adapter.py
index fa9a23ae..11a015b8 100644
--- a/app/demo_adapter.py
+++ b/app/demo_adapter.py
@@ -30,6 +30,7 @@
from .routers.status import models as status_models
from .routers.task import facility_adapter as task_adapter
from .routers.task import models as task_models
+from .request_context import get_iri_facility_project
from .types.models import Capability
from .types.user import User
from .types.scalars import AllocationUnit
@@ -542,6 +543,8 @@ async def submit_job(
user: User,
job_spec: compute_models.JobSpec,
) -> compute_models.Job:
+ facility_project = get_iri_facility_project()
+ account = facility_project or (job_spec.attributes.account if job_spec.attributes else None)
return compute_models.Job(
id="job_123",
status=compute_models.JobStatus(
@@ -549,7 +552,7 @@ async def submit_job(
time=utc_timestamp(),
message="job submitted",
exit_code=0,
- meta_data={"account": "account1"},
+ meta_data={"account": account},
),
)
@@ -560,6 +563,8 @@ async def update_job(
job_spec: compute_models.JobSpec,
job_id: str,
) -> compute_models.Job:
+ facility_project = get_iri_facility_project()
+ account = facility_project or (job_spec.attributes.account if job_spec.attributes else None)
return compute_models.Job(
id=job_id,
status=compute_models.JobStatus(
@@ -567,7 +572,7 @@ async def update_job(
time=utc_timestamp(),
message="job updated",
exit_code=0,
- meta_data={"account": "account1"},
+ meta_data={"account": account},
),
)
diff --git a/app/main.py b/app/main.py
index 1f6faccc..30480024 100644
--- a/app/main.py
+++ b/app/main.py
@@ -17,7 +17,7 @@
from . import config
from .apilogger import configure_logging
-from .request_context import set_api_url_base, _api_url_base
+from .request_context import _api_url_base, _iri_facility_project, set_api_url_base
from app.routers.error_handlers import install_error_handlers
from app.routers.facility import facility
@@ -58,12 +58,14 @@
class _ExternalRequestContextMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
- token = _api_url_base.set(None)
+ url_token = _api_url_base.set(None)
+ facility_project_token = _iri_facility_project.set(None)
try:
set_api_url_base(request)
return await call_next(request)
finally:
- _api_url_base.reset(token)
+ _api_url_base.reset(url_token)
+ _iri_facility_project.reset(facility_project_token)
APP.add_middleware(_ExternalRequestContextMiddleware)
diff --git a/app/request_context.py b/app/request_context.py
index cc8c8828..8f95f072 100644
--- a/app/request_context.py
+++ b/app/request_context.py
@@ -6,6 +6,7 @@
from . import config
_api_url_base: ContextVar[str | None] = ContextVar("_api_url_base", default=None)
+_iri_facility_project: ContextVar[str | None] = ContextVar("_iri_facility_project", default=None)
def _first_header_value(value: str | None) -> str:
@@ -22,6 +23,8 @@ def set_api_url_base(request: Request) -> None:
api_url = config.API_URL.strip("/")
if host:
_api_url_base.set(f"{proto}://{host}{prefix}{api_prefix}/{api_url}")
+ facility_project = _first_header_value(request.headers.get("x-iri-facility-project"))
+ _iri_facility_project.set(facility_project or None)
def get_url_prefix() -> str:
@@ -30,3 +33,8 @@ def get_url_prefix() -> str:
if value:
return value
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}"
+
+
+def get_iri_facility_project() -> str | None:
+ """Return the facility-native project/account identifier forwarded by RIG."""
+ return _iri_facility_project.get()
diff --git a/app/routers/compute/compute.py b/app/routers/compute/compute.py
index 71c80b45..08507d6a 100644
--- a/app/routers/compute/compute.py
+++ b/app/routers/compute/compute.py
@@ -1,6 +1,6 @@
"""Compute resource API router"""
-from fastapi import Depends, HTTPException, Query, Request, status
+from fastapi import Depends, Query, Request, status
from ...types.http import forbidExtraQueryParams
from ...types.scalars import StrictHTTPBool
@@ -16,8 +16,6 @@
prefix="/compute",
tags=["compute"],
)
-
-
@router.post(
"/job/{resource_id:str}",
response_model=models.Job,
@@ -31,6 +29,7 @@ async def submit_job(
job_spec: models.JobSpec,
request: Request,
user: User = Depends(router.current_user),
+ project_name: str | None = Depends(router.iri_header_project),
_forbid=Depends(forbidExtraQueryParams()),
):
"""
@@ -38,6 +37,12 @@ async def submit_job(
- **resource**: the name of the compute resource to use
- **job_request**: a PSIJ job spec as defined here
+ - **project/account resolution**:
+ The effective project/account for the submission must be supplied in exactly one place:
+ `job_spec.attributes.account` or the trusted `X-IRI-Facility-Project` request header.
+ If the forwarded header is present and valid, IRI treats its value as the effective facility-native project/account
+ for the downstream submission and related job metadata. If both sources are present, or neither is present,
+ the request is rejected with `400 Bad Request`.
This command will attempt to submit a job and return its id.
"""
@@ -63,6 +68,7 @@ async def update_job(
job_spec: models.JobSpec,
request: Request,
user: User = Depends(router.current_user),
+ project_name: str | None = Depends(router.iri_header_project),
_forbid=Depends(forbidExtraQueryParams()),
):
"""
@@ -71,6 +77,12 @@ async def update_job(
- **resource**: the name of the compute resource to use
- **job_request**: a PSIJ job spec as defined here
+ - **project/account resolution**:
+ The effective project/account for the update must be supplied in exactly one place:
+ `job_spec.attributes.account` or the trusted `X-IRI-Facility-Project` request header.
+ If the forwarded header is present and valid, IRI treats its value as the effective facility-native project/account
+ for downstream update handling and job metadata. If both sources are present, or neither is present,
+ the request is rejected with `400 Bad Request`.
"""
# look up the resource (todo: maybe ensure it's available)
diff --git a/app/routers/compute/models.py b/app/routers/compute/models.py
index cea26492..7c803d17 100644
--- a/app/routers/compute/models.py
+++ b/app/routers/compute/models.py
@@ -28,7 +28,16 @@ class JobAttributes(IRIBaseModel):
duration: int|None = Field(default=None, description="Duration in seconds", ge=1, examples=[30, 60, 120])
queue_name: str|None = Field(default=None, min_length=1, description="Name of the queue or partition to submit the job to", example="debug")
- account: str|None = Field(default=None, min_length=1, description="Account or project to charge for resource usage", example="proj123")
+ account: str|None = Field(
+ default=None,
+ min_length=1,
+ description=(
+ "Account or project to charge for resource usage. "
+ "For compute submission/update requests, specify this here only when the caller is not relying on a trusted forwarded "
+ "`X-IRI-Facility-Project` header. If that header is present and valid, this field must be omitted."
+ ),
+ example="proj123",
+ )
reservation_id: str|None = Field(default=None, min_length=1, description="ID of a reservation to use for the job", example="resv-42")
custom_attributes: dict[str, str] = Field(default_factory=dict, description="Custom scheduler-specific attributes as key-value pairs", example={"constraint": "gpu"})
@@ -79,7 +88,14 @@ class JobSpec(IRIBaseModel):
stdout_path: str|None = Field(default=None, min_length=1, description="Path to file to write standard output", example="/home/user/output.txt")
stderr_path: str|None = Field(default=None, min_length=1, description="Path to file to write standard error", example="/home/user/error.txt")
resources: ResourceSpec|None = Field(default=None, description="Resource requirements for the job")
- attributes: JobAttributes|None = Field(default=None, description="Additional job attributes such as duration, queue, and account")
+ attributes: JobAttributes|None = Field(
+ default=None,
+ description=(
+ "Additional job attributes such as duration, queue, and account. "
+ "For compute submission/update, the effective project/account must be supplied in exactly one place: "
+ "`attributes.account` or the trusted `X-IRI-Facility-Project` request header."
+ ),
+ )
pre_launch: str|None = Field(default=None, min_length=1, description="Script or commands to run before launching the job", example="module load cuda")
post_launch: str|None = Field(default=None, min_length=1, description="Script or commands to run after the job completes", example="echo done")
launcher: str|None = Field(default=None, min_length=1, description="Job launcher to use (e.g., 'mpirun', 'srun')", example="srun")
diff --git a/app/routers/iri_router.py b/app/routers/iri_router.py
index 8abcc4b5..0542193f 100644
--- a/app/routers/iri_router.py
+++ b/app/routers/iri_router.py
@@ -3,10 +3,12 @@
import logging
import importlib
import time
+from typing import Any
import globus_sdk
-from fastapi import Request, Depends, HTTPException, APIRouter
+from fastapi import Body, Request, Depends, HTTPException, APIRouter
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from ..request_context import get_iri_facility_project
from ..types.user import User
bearer_scheme = HTTPBearer()
@@ -159,6 +161,29 @@ async def current_user(
raise HTTPException(status_code=404, detail="User not found")
return user
+ async def iri_header_project(self, request: Request, job_spec: dict[str, Any] | None = Body(default=None)) -> str | None:
+ """Expose and validate the forwarded facility-project header for compute routes."""
+ project_name = get_iri_facility_project()
+ spec_account = None
+ if job_spec is not None:
+ attributes = job_spec.get("attributes")
+ if isinstance(attributes, dict):
+ spec_account = attributes.get("account")
+ elif attributes is not None:
+ # Leave malformed body handling to FastAPI/Pydantic validation.
+ return project_name
+ if spec_account and project_name:
+ raise HTTPException(
+ status_code=400,
+ detail="Specify project/account in exactly one place: job_spec.attributes.account or X-IRI-Facility-Project, not both.",
+ )
+ if not spec_account and not project_name:
+ raise HTTPException(
+ status_code=400,
+ detail="Project/account must be specified in exactly one place: job_spec.attributes.account or X-IRI-Facility-Project.",
+ )
+ return project_name
+
class AuthenticatedAdapter(ABC):
@abstractmethod
diff --git a/test/test_facility_project_header.py b/test/test_facility_project_header.py
new file mode 100644
index 00000000..d2a79d4a
--- /dev/null
+++ b/test/test_facility_project_header.py
@@ -0,0 +1,145 @@
+#!/usr/bin/env python3
+"""Regression tests for facility-project header propagation into compute submission."""
+
+import os
+import unittest
+
+from fastapi.testclient import TestClient
+
+os.environ.setdefault("IRI_SHOW_MISSING_ROUTES", "true")
+
+from app.main import APP
+
+
+class FacilityProjectHeaderTests(unittest.TestCase):
+ def test_compute_submit_uses_forwarded_facility_project_header(self):
+ client = TestClient(APP)
+
+ resources_response = client.get("/api/v1/status/resources")
+ self.assertEqual(resources_response.status_code, 200)
+ resource_id = resources_response.json()[0]["id"]
+
+ response = client.post(
+ f"/api/v1/compute/job/{resource_id}",
+ headers={
+ "authorization": "Bearer 12345",
+ "x-iri-facility-project": "ns011",
+ },
+ json={"executable": "/bin/echo", "arguments": ["hello"]},
+ )
+
+ self.assertEqual(response.status_code, 200)
+ body = response.json()
+ self.assertEqual(body["status"]["meta_data"]["account"], "ns011")
+
+ def test_compute_submit_uses_job_spec_account_when_header_absent(self):
+ client = TestClient(APP)
+
+ resources_response = client.get("/api/v1/status/resources")
+ self.assertEqual(resources_response.status_code, 200)
+ resource_id = resources_response.json()[0]["id"]
+
+ response = client.post(
+ f"/api/v1/compute/job/{resource_id}",
+ headers={"authorization": "Bearer 12345"},
+ json={
+ "executable": "/bin/echo",
+ "arguments": ["hello"],
+ "attributes": {"account": "ns011"},
+ },
+ )
+
+ self.assertEqual(response.status_code, 200)
+ body = response.json()
+ self.assertEqual(body["status"]["meta_data"]["account"], "ns011")
+
+ def test_compute_submit_rejects_missing_project_account(self):
+ client = TestClient(APP)
+
+ resources_response = client.get("/api/v1/status/resources")
+ self.assertEqual(resources_response.status_code, 200)
+ resource_id = resources_response.json()[0]["id"]
+
+ response = client.post(
+ f"/api/v1/compute/job/{resource_id}",
+ headers={"authorization": "Bearer 12345"},
+ json={"executable": "/bin/echo", "arguments": ["hello"]},
+ )
+
+ self.assertEqual(response.status_code, 400)
+ self.assertIn("exactly one place", response.json()["detail"])
+
+ def test_compute_submit_rejects_duplicate_project_account_sources(self):
+ client = TestClient(APP)
+
+ resources_response = client.get("/api/v1/status/resources")
+ self.assertEqual(resources_response.status_code, 200)
+ resource_id = resources_response.json()[0]["id"]
+
+ response = client.post(
+ f"/api/v1/compute/job/{resource_id}",
+ headers={
+ "authorization": "Bearer 12345",
+ "x-iri-facility-project": "ns011",
+ },
+ json={
+ "executable": "/bin/echo",
+ "arguments": ["hello"],
+ "attributes": {"account": "also-present"},
+ },
+ )
+
+ self.assertEqual(response.status_code, 400)
+ self.assertIn("not both", response.json()["detail"])
+
+ def test_compute_submit_requires_authorization_before_project_validation(self):
+ client = TestClient(APP)
+
+ response = client.post(
+ "/api/v1/compute/job/0",
+ json={"executable": "/bin/echo", "arguments": ["hello"]},
+ )
+
+ self.assertEqual(response.status_code, 401)
+
+ def test_compute_update_requires_authorization_before_project_validation(self):
+ client = TestClient(APP)
+
+ response = client.put(
+ "/api/v1/compute/job/0/0",
+ json={"executable": "/bin/echo", "arguments": ["hello"]},
+ )
+
+ self.assertEqual(response.status_code, 401)
+
+ def test_compute_submit_malformed_attributes_does_not_500(self):
+ client = TestClient(APP)
+ resources_response = client.get("/api/v1/status/resources")
+ self.assertEqual(resources_response.status_code, 200)
+ resource_id = resources_response.json()[0]["id"]
+
+ response = client.post(
+ f"/api/v1/compute/job/{resource_id}",
+ headers={"authorization": "Bearer 12345"},
+ json={"executable": "/bin/echo", "arguments": ["hello"], "attributes": [None, None]},
+ )
+
+ self.assertIn(response.status_code, {400, 422})
+
+ def test_compute_update_malformed_attributes_does_not_500(self):
+ client = TestClient(APP)
+ resources_response = client.get("/api/v1/status/resources")
+ self.assertEqual(resources_response.status_code, 200)
+ resource_id = resources_response.json()[0]["id"]
+
+ response = client.put(
+ f"/api/v1/compute/job/{resource_id}/0",
+ headers={"authorization": "Bearer 12345"},
+ json={"executable": "/bin/echo", "arguments": ["hello"], "attributes": [None, None]},
+ )
+
+ self.assertIn(response.status_code, {400, 422})
+
+
+if __name__ == "__main__":
+ unittest.main()