Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 33 additions & 44 deletions app/s3df/compute_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
)
from fastapi import Response

from ..routers.compute import models as compute_models
from ..types.user import User
from ..routers.status import models as status_models

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -226,7 +230,12 @@ def _get_slurm_context(self, user):

# -- submit_job ---------------------------------------------------------

async def submit_job(self, resource, user, job_spec) -> dict:
async def submit_job(
self,
resource: status_models.Resource,
user: User,
job_spec: compute_models.JobSpec,
) -> compute_models.Job:
"""
POST /compute/job/{resource_id}
Maps IRI JobSpec → SlurmV0041PostJobSubmitRequest and submits.
Expand All @@ -239,48 +248,31 @@ async def submit_job(self, resource, user, job_spec) -> dict:
partition = None
account = None
environment = ["PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"]
name = None
executable = None
cwd = None
stdout = None
stderr = None

if job_spec:
name = getattr(job_spec, "name", None)
executable = getattr(job_spec, "executable", None)
cwd = str(job_spec.directory) if getattr(job_spec, "directory", None) else None
stdout = getattr(job_spec, "stdout_path", None)
stderr = getattr(job_spec, "stderr_path", None)
name = job_spec.name
executable = job_spec.executable
cwd = str(job_spec.directory) if job_spec.directory else None
stdout = job_spec.stdout_path
stderr = job_spec.stderr_path

if getattr(job_spec, "environment", None):
environment = [f"{k}={v}" for k, v in job_spec.environment.items()]
if job_spec.environment:
environment = [f"{k}={v}" for k, v in job_spec.environment.items()]

resources = getattr(job_spec, "resources", None)
if resources:
node_count = getattr(resources, "node_count", 1) or 1
if job_spec.resources:
node_count = job_spec.resources.node_count or 1

attributes = getattr(job_spec, "attributes", None)
if attributes:
duration = getattr(attributes, "duration", None)
if duration is not None:
# duration may be timedelta or int (seconds)
total_secs = (
duration.total_seconds()
if hasattr(duration, "total_seconds")
else float(duration)
)
duration_mins = max(1, int(total_secs // 60))
partition = getattr(attributes, "queue_name", None)
account = getattr(attributes, "account", None)
if job_spec.attributes:
if job_spec.attributes.duration is not None:
duration_mins = max(1, int(job_spec.attributes.duration // 60))
partition = job_spec.attributes.queue_name
account = job_spec.attributes.account

partition = partition or os.environ.get("SLURM_DEFAULT_PARTITION")
account = account or os.environ.get("SLURM_DEFAULT_ACCOUNT")

slurm_job = SlurmV0041PostJobSubmitRequestJob(
nodes=str(node_count),
time_limit=SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit(
set=True, number=duration_mins
),
time_limit=SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit(set=True, number=duration_mins),
name=name,
script=executable,
partition=partition,
Expand All @@ -292,12 +284,8 @@ async def submit_job(self, resource, user, job_spec) -> dict:
)

# Job array support: e.g. custom_attributes={"array": "0-19"}
if job_spec:
attributes = getattr(job_spec, "attributes", None)
if attributes:
ca = getattr(attributes, "custom_attributes", {}) or {}
if "array" in ca:
slurm_job.array = ca["array"]
if job_spec.attributes and "array" in job_spec.attributes.custom_attributes:
slurm_job.array = job_spec.attributes.custom_attributes["array"]

req = SlurmV0041PostJobSubmitRequest(job=slurm_job)

Expand All @@ -307,10 +295,11 @@ async def submit_job(self, resource, user, job_spec) -> dict:
_headers=headers,
)
logger.info("Job submitted: job_id=%s", resp.job_id)
return {
"id": str(resp.job_id),
"status": {"state": JobState.QUEUED},
}
return compute_models.Job(
id=str(resp.job_id),
# TODO: check whether a 200 response always means the job is queued
status=compute_models.JobStatus(state=JobState.QUEUED),
)
Comment on lines +298 to +302
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep submit_job return shape backward-compatible

Returning a compute_models.Job object here changes submit_job from a mapping-like result to a model instance, which breaks existing direct callers that still index into the result (for example app/s3df/slurm/test_slurmrestd.py uses result["id"] and result["status"]["state"]). In environments that invoke the adapter directly (outside FastAPI response serialization), this will raise at runtime and stop job-submission workflows.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amithmslac this will break some tests, I guess.

except ApiException as exc:
logger.error("submit_job failed: %s", exc)
raise RuntimeError(f"Slurm submission failed: {exc}") from exc
Expand Down Expand Up @@ -455,7 +444,7 @@ async def get_jobs(
jobs = [j for j in jobs if getattr(j, key, None) == value]

# Pagination
jobs = jobs[offset: offset + limit]
jobs = jobs[offset : offset + limit]

return [_job_from_slurm_info(j, include_spec) for j in jobs]

Expand Down
Loading