diff --git a/src/dvsim/launcher/slurm.py b/src/dvsim/launcher/slurm.py index 12d36b41..3fe61b7c 100644 --- a/src/dvsim/launcher/slurm.py +++ b/src/dvsim/launcher/slurm.py @@ -34,6 +34,7 @@ def __init__(self, deploy) -> None: # Popen object when launching the job. self.process = None + self._log_file = None self.slurm_log_file = f"{self.job_spec.log_path}.slurm" def _do_launch(self) -> None: @@ -57,34 +58,35 @@ def _do_launch(self) -> None: slurm_setup_cmd += ";" # Encapsulate the run command with the slurm invocation + full_cmd = f"{slurm_setup_cmd} {self.job_spec.cmd}".strip() slurm_cmd = ( - f"srun -p {SLURM_QUEUE} --mem={SLURM_MEM} --mincpus={SLURM_MINCPUS} " + f"srun -p {SLURM_QUEUE} --mem={SLURM_MEM} " + f"--mincpus={SLURM_MINCPUS} " f"--time={SLURM_TIMEOUT} --cpus-per-task={SLURM_CPUS_PER_TASK} " - f'bash -c "{slurm_setup_cmd} {self.job_spec.cmd}"' + f"bash -c {shlex.quote(full_cmd)}" ) try: - with pathlib.Path(self.slurm_log_file).open("w") as out_file: - out_file.write(f"[Executing]:\n{self.job_spec.cmd}\n\n") - out_file.flush() - - log.info(f"Executing slurm command: {slurm_cmd}") - self.process = subprocess.Popen( - shlex.split(slurm_cmd), - bufsize=4096, - universal_newlines=True, - stdout=out_file, - stderr=out_file, - env=exports, - ) + self._log_file = pathlib.Path(self.slurm_log_file).open("w") + self._log_file.write(f"[Executing]:\n{self.job_spec.cmd}\n\n") + self._log_file.flush() + + log.info(f"Executing slurm command: {slurm_cmd}") + self.process = subprocess.Popen( + shlex.split(slurm_cmd), + text=True, + stdout=self._log_file, + stderr=self._log_file, + env=exports, + ) except OSError as e: + self._close_log_file() msg = f"File Error: {e}\nError while handling {self.slurm_log_file}" raise LauncherError(msg) except subprocess.SubprocessError as e: + self._close_log_file() msg = f"IO Error: {e}\nSee {self.job_spec.log_path}" raise LauncherError(msg) - finally: - self._close_process() def poll(self) -> JobStatus: """Check status of the running process. @@ -139,19 +141,24 @@ def kill(self) -> None: except subprocess.TimeoutExpired: self.process.kill() self._post_finish( - JobStatus.KILLED, ErrorMessage(line_number=None, message="Job killed!", context=[]) + JobStatus.KILLED, + ErrorMessage( + line_number=None, + message="Job killed!", + context=[], + ), ) def _post_finish(self, status, err_msg) -> None: + self._close_log_file() super()._post_finish(status, err_msg) - self._close_process() self.process = None - def _close_process(self) -> None: - """Close the file descriptors associated with the process.""" - assert self.process - if self.process.stdout: - self.process.stdout.close() + def _close_log_file(self) -> None: + """Close the log file if it is open.""" + if self._log_file: + self._log_file.close() + self._log_file = None @staticmethod def prepare_workspace(cfg: "WorkspaceConfig") -> None: