Skip to content

Commit 797d7c5

Browse files
0.15.20
- xai.py determines shape of the dataset also for tensors - save_experiment in spot uses copy if deep.copy fails
1 parent f599f84 commit 797d7c5

3 files changed

Lines changed: 10 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.19"
10+
version = "0.15.20"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/plot/xai.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,10 @@ def get_attributions(
618618
model = model.to("cpu")
619619
model.eval()
620620
dataset = fun_control["data_set"]
621-
n_features = dataset.data.shape[1]
621+
try:
622+
n_features = dataset.data.shape[1]
623+
except AttributeError:
624+
n_features = dataset.tensors[0].shape[1]
622625
if feature_names is None:
623626
feature_names = [f"x{i}" for i in range(n_features)]
624627
batch_size = config["batch_size"]
@@ -745,7 +748,10 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx) -> np.ndarray:
745748
model.eval()
746749

747750
dataset = fun_control["data_set"]
748-
n_features = dataset.data.shape[1]
751+
try:
752+
n_features = dataset.data.shape[1]
753+
except AttributeError:
754+
n_features = dataset.tensors[0].shape[1]
749755
if feature_names is None:
750756
feature_names = [f"x{i}" for i in range(n_features)]
751757
batch_size = config["batch_size"]

src/spotpython/spot/spot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2314,7 +2314,7 @@ def save_experiment(self, filename=None) -> None:
23142314
except Exception as e:
23152315
logger.warning("Warning: Could not copy/save spot_tuner object!")
23162316
logger.warning(f"Error: {e}")
2317-
spot_tuner = None
2317+
spot_tuner = copy(self)
23182318
experiment = {
23192319
"design_control": design_control,
23202320
"fun_control": fun_control,

0 commit comments

Comments
 (0)