diff --git a/app/s3df/compute_adapter.py b/app/s3df/compute_adapter.py index 037152fa..28df2e55 100644 --- a/app/s3df/compute_adapter.py +++ b/app/s3df/compute_adapter.py @@ -37,6 +37,10 @@ ) from fastapi import Response +from ..routers.compute import models as compute_models +from ..types.user import User +from ..routers.status import models as status_models + logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -226,7 +230,12 @@ def _get_slurm_context(self, user): # -- submit_job --------------------------------------------------------- - async def submit_job(self, resource, user, job_spec) -> dict: + async def submit_job( + self, + resource: status_models.Resource, + user: User, + job_spec: compute_models.JobSpec, + ) -> compute_models.Job: """ POST /compute/job/{resource_id} Maps IRI JobSpec → SlurmV0041PostJobSubmitRequest and submits. @@ -239,48 +248,31 @@ async def submit_job(self, resource, user, job_spec) -> dict: partition = None account = None environment = ["PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"] - name = None - executable = None - cwd = None - stdout = None - stderr = None - if job_spec: - name = getattr(job_spec, "name", None) - executable = getattr(job_spec, "executable", None) - cwd = str(job_spec.directory) if getattr(job_spec, "directory", None) else None - stdout = getattr(job_spec, "stdout_path", None) - stderr = getattr(job_spec, "stderr_path", None) + name = job_spec.name + executable = job_spec.executable + cwd = str(job_spec.directory) if job_spec.directory else None + stdout = job_spec.stdout_path + stderr = job_spec.stderr_path - if getattr(job_spec, "environment", None): - environment = [f"{k}={v}" for k, v in job_spec.environment.items()] + if job_spec.environment: + environment = [f"{k}={v}" for k, v in job_spec.environment.items()] - resources = getattr(job_spec, "resources", None) - if resources: - node_count = getattr(resources, "node_count", 1) or 1 + if job_spec.resources: + node_count = job_spec.resources.node_count or 1 - attributes = getattr(job_spec, "attributes", None) - if attributes: - duration = getattr(attributes, "duration", None) - if duration is not None: - # duration may be timedelta or int (seconds) - total_secs = ( - duration.total_seconds() - if hasattr(duration, "total_seconds") - else float(duration) - ) - duration_mins = max(1, int(total_secs // 60)) - partition = getattr(attributes, "queue_name", None) - account = getattr(attributes, "account", None) + if job_spec.attributes: + if job_spec.attributes.duration is not None: + duration_mins = max(1, int(job_spec.attributes.duration // 60)) + partition = job_spec.attributes.queue_name + account = job_spec.attributes.account partition = partition or os.environ.get("SLURM_DEFAULT_PARTITION") account = account or os.environ.get("SLURM_DEFAULT_ACCOUNT") slurm_job = SlurmV0041PostJobSubmitRequestJob( nodes=str(node_count), - time_limit=SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit( - set=True, number=duration_mins - ), + time_limit=SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit(set=True, number=duration_mins), name=name, script=executable, partition=partition, @@ -292,12 +284,8 @@ async def submit_job(self, resource, user, job_spec) -> dict: ) # Job array support: e.g. custom_attributes={"array": "0-19"} - if job_spec: - attributes = getattr(job_spec, "attributes", None) - if attributes: - ca = getattr(attributes, "custom_attributes", {}) or {} - if "array" in ca: - slurm_job.array = ca["array"] + if job_spec.attributes and "array" in job_spec.attributes.custom_attributes: + slurm_job.array = job_spec.attributes.custom_attributes["array"] req = SlurmV0041PostJobSubmitRequest(job=slurm_job) @@ -307,10 +295,11 @@ async def submit_job(self, resource, user, job_spec) -> dict: _headers=headers, ) logger.info("Job submitted: job_id=%s", resp.job_id) - return { - "id": str(resp.job_id), - "status": {"state": JobState.QUEUED}, - } + return compute_models.Job( + id=str(resp.job_id), + # TODO: check if 200 always mean it is queued + status=compute_models.JobStatus(state=JobState.QUEUED), + ) except ApiException as exc: logger.error("submit_job failed: %s", exc) raise RuntimeError(f"Slurm submission failed: {exc}") from exc @@ -455,7 +444,7 @@ async def get_jobs( jobs = [j for j in jobs if getattr(j, key, None) == value] # Pagination - jobs = jobs[offset: offset + limit] + jobs = jobs[offset : offset + limit] return [_job_from_slurm_info(j, include_spec) for j in jobs]