diff --git a/app/s3df/compute_adapter.py b/app/s3df/compute_adapter.py index 47978cd0..e7450395 100644 --- a/app/s3df/compute_adapter.py +++ b/app/s3df/compute_adapter.py @@ -35,6 +35,9 @@ from slurmrestd_client.models.slurm_v0041_post_job_submit_request_jobs_inner_time_limit import ( SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit, ) +from slurmrestd_client.models.slurm_v0041_post_job_submit_request_jobs_inner_memory_per_cpu import ( + SlurmV0041PostJobSubmitRequestJobsInnerMemoryPerCpu, +) from fastapi import HTTPException, Response from pydantic import ConfigDict, ValidationError @@ -257,14 +260,23 @@ async def submit_job( # --- resource fields with safe defaults --- node_count = 1 + tasks = None + tasks_per_node = None + cpus_per_task = None + tres_per_task = None + exclusive = ["true"] + memory_per_node = None duration_mins = 60 partition = None account = None + reservation = None environment = ["PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"] name = job_spec.name executable = job_spec.executable + argv = job_spec.arguments or None cwd = str(job_spec.directory) if job_spec.directory else None + stdin = job_spec.stdin_path stdout = job_spec.stdout_path stderr = job_spec.stderr_path @@ -273,12 +285,23 @@ async def submit_job( if job_spec.resources: node_count = job_spec.resources.node_count or 1 + tasks = job_spec.resources.process_count + tasks_per_node = job_spec.resources.processes_per_node + cpus_per_task = job_spec.resources.cpu_cores_per_process + if job_spec.resources.gpu_cores_per_process: + tres_per_task = f"gres/gpu:{job_spec.resources.gpu_cores_per_process}" + if not job_spec.resources.exclusive_node_use: + exclusive = ["false"] + if job_spec.resources.memory: + memory_mb = max(1, job_spec.resources.memory // (1024 * 1024)) + memory_per_node = SlurmV0041PostJobSubmitRequestJobsInnerMemoryPerCpu(set=True, number=memory_mb) if job_spec.attributes: if job_spec.attributes.duration is not None: duration_mins = max(1, int(job_spec.attributes.duration // 60)) partition = job_spec.attributes.queue_name account = job_spec.attributes.account + reservation = job_spec.attributes.reservation_id partition = partition or os.environ.get("SLURM_DEFAULT_PARTITION") account = account or os.environ.get("SLURM_DEFAULT_ACCOUNT") @@ -288,13 +311,22 @@ async def submit_job( try: slurm_job = SlurmV0041PostJobSubmitRequestJobStrict( nodes=str(node_count), + tasks=tasks, + tasks_per_node=tasks_per_node, + cpus_per_task=cpus_per_task, + tres_per_task=tres_per_task, + exclusive=exclusive, + memory_per_node=memory_per_node, time_limit=SlurmV0041PostJobSubmitRequestJobsInnerTimeLimit(set=True, number=duration_mins), name=name, script=executable, + argv=argv, partition=partition, account=account, + reservation=reservation, environment=environment, current_working_directory=cwd, + standard_input=stdin, standard_output=stdout, standard_error=stderr, **custom_attributes