Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion kwave/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ def options_to_kwargs(simulation_options=None, execution_options=None):
kwargs["use_kspace"] = opts.use_kspace
kwargs["smooth_p0"] = opts.smooth_p0
if opts.data_path is not None:
kwargs["data_path"] = opts.data_path
import os
from tempfile import gettempdir

normalized_data_path = os.path.realpath(os.path.normpath(os.fspath(opts.data_path)))
normalized_tempdir = os.path.realpath(os.path.normpath(os.fspath(gettempdir())))
if normalized_data_path != normalized_tempdir:
kwargs["data_path"] = opts.data_path
Comment thread
aconesac marked this conversation as resolved.
if opts.save_to_disk_exit:
kwargs["save_only"] = True

Expand Down
17 changes: 16 additions & 1 deletion kwave/solvers/cpp_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Handles HDF5 serialization and C++ binary execution without
depending on kWaveSimulation or the legacy options classes.
"""

import os
import shutil
import stat
Expand Down Expand Up @@ -46,6 +47,11 @@ def prepare(self, data_path=None):
input_file = os.path.join(data_path, "kwave_input.h5")
output_file = os.path.join(data_path, "kwave_output.h5")

if os.path.exists(input_file):
raise FileExistsError(
f"{input_file!r} already exists. Delete it or choose a different data_path to avoid overwriting previous simulation inputs."
)
Comment thread
waltsims marked this conversation as resolved.

self._write_hdf5(input_file)
return input_file, output_file

Expand All @@ -56,7 +62,16 @@ def run(self, *, device="cpu", num_threads=None, device_num=None, quiet=False, d
input_file, output_file = self.prepare(data_path=data_path)
data_dir = os.path.dirname(input_file)
try:
self._execute(input_file, output_file, device=device, num_threads=num_threads, device_num=device_num, quiet=quiet, debug=debug, binary_path=binary_path)
self._execute(
input_file,
output_file,
device=device,
num_threads=num_threads,
device_num=device_num,
quiet=quiet,
debug=debug,
binary_path=binary_path,
)
result = self._parse_output(output_file)
result = self._fix_output_order(result)
return result
Expand Down
12 changes: 12 additions & 0 deletions tests/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,15 @@ def test_both_options(self):
def test_none_options(self):
kwargs = options_to_kwargs()
assert kwargs == {}

def test_default_data_path_not_forwarded(self):
# Before fix: data_path defaulted to gettempdir() and was always forwarded,
# so every run targeted /tmp/kwave_input.h5 and crashed on the second call.
kwargs = options_to_kwargs(simulation_options=SimulationOptions())
assert "data_path" not in kwargs

def test_custom_data_path_is_forwarded(self, tmp_path):
opts = SimulationOptions()
opts.data_path = str(tmp_path)
kwargs = options_to_kwargs(simulation_options=opts)
assert kwargs["data_path"] == str(tmp_path)
10 changes: 10 additions & 0 deletions tests/test_cpp_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@
from kwave.solvers.cpp_simulation import CppSimulation


class TestPrepare:
def test_raises_when_input_already_exists(self, tmp_path):
# Before fix: h5py raised a cryptic ValueError('name already exists')
# instead of a clear FileExistsError.
(tmp_path / "kwave_input.h5").write_bytes(b"")
sim = CppSimulation.__new__(CppSimulation)
with pytest.raises(FileExistsError, match="already exists"):
sim.prepare(data_path=str(tmp_path))


class TestResolveBinaryPath:
"""Tests for CppSimulation._resolve_binary_path().

Expand Down
Loading