Skip to content

Commit efd0313

Browse files
0.16.2
xai non-torch
1 parent 226efde commit efd0313

3 files changed

Lines changed: 39 additions & 8 deletions

File tree

MANIFEST.in

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,10 @@ include src/spotpython/data/*.csv
22
include src/spotpython/data/*.json
33
include src/spotpython/data/*.pkl
44
include src/spotpython/hyperdict/*.json
5+
6+
prune docs
7+
prune Figures.d
8+
prune lightning_logs
9+
prune notebooks
10+
prune runs
11+
prune runs_OLD

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.16.1"
10+
version = "0.16.2"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/plot/xai.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,7 @@ def get_attributions(
580580
n_rel=5,
581581
device="cpu",
582582
normalize=True,
583+
remove_spot_attributes=False,
583584
) -> pd.DataFrame:
584585
"""Get the attributions of a neural network.
585586
@@ -600,6 +601,9 @@ def get_attributions(
600601
The device to use. Defaults to "cpu".
601602
normalize (bool, optional):
602603
Whether to normalize the input data. Defaults to True.
604+
remove_spot_attributes (bool, optional):
605+
Whether to remove the spot attributes.
606+
If True, a torch model is created via `get_removed_attributes`. Defaults to False.
603607
604608
Returns:
605609
pd.DataFrame (object): A DataFrame with the attributions.
@@ -613,7 +617,10 @@ def get_attributions(
613617
config = get_tuned_architecture(spot_tuner, fun_control)
614618
train_model(config, fun_control, timestamp=False)
615619
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
616-
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
620+
if remove_spot_attributes:
621+
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
622+
else:
623+
model = model_loaded
617624
model = model.to(device)
618625
model.eval()
619626
# get feature names
@@ -731,7 +738,9 @@ def is_square(n) -> bool:
731738
return n == int(math.sqrt(n)) ** 2
732739

733740

734-
def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu", normalize=True) -> np.ndarray:
741+
def get_layer_conductance(
742+
spot_tuner, fun_control, layer_idx, device="cpu", normalize=True, remove_spot_attributes=False
743+
) -> np.ndarray:
735744
"""
736745
Compute the average layer conductance attributions for a specified layer in the model.
737746
@@ -746,6 +755,8 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu", norm
746755
The device to use. Defaults to "cpu".
747756
normalize (bool, optional):
748757
Whether to normalize the input data. Defaults to True.
758+
remove_spot_attributes (bool, optional):
759+
Whether to remove the spot attributes. Defaults to False.
749760
750761
Returns:
751762
numpy.ndarray:
@@ -761,7 +772,10 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu", norm
761772
config = get_tuned_architecture(spot_tuner, fun_control)
762773
train_model(config, fun_control, timestamp=False)
763774
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
764-
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
775+
if remove_spot_attributes:
776+
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
777+
else:
778+
model = model_loaded
765779
model = model.to(device)
766780
model.eval()
767781

@@ -794,7 +808,7 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu", norm
794808
return avg_layer_attributions
795809

796810

797-
def get_weights_conductance_last_layer(spot_tuner, fun_control, device="cpu") -> tuple:
811+
def get_weights_conductance_last_layer(spot_tuner, fun_control, device="cpu", remove_spot_attributes=False) -> tuple:
798812
"""
799813
Get the weights and the conductance of the last layer.
800814
@@ -805,11 +819,16 @@ def get_weights_conductance_last_layer(spot_tuner, fun_control, device="cpu") ->
805819
A dictionary with the function control.
806820
device (str, optional):
807821
The device to use. Defaults to "cpu".
822+
remove_spot_attributes (bool, optional):
823+
Whether to remove the spot attributes. Defaults to False.
808824
"""
809825
config = get_tuned_architecture(spot_tuner, fun_control)
810826
train_model(config, fun_control, timestamp=False)
811827
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
812-
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
828+
if remove_spot_attributes:
829+
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
830+
else:
831+
model = model_loaded
813832
model = model.to(device)
814833
model.eval()
815834

@@ -860,7 +879,7 @@ def plot_conductance_last_layer(weights_last, layer_conductance_last, figsize=(1
860879
plt.show()
861880

862881

863-
def get_all_layers_conductance(spot_tuner, fun_control, device="cpu") -> dict:
882+
def get_all_layers_conductance(spot_tuner, fun_control, device="cpu", remove_spot_attributes=False) -> dict:
864883
"""
865884
Get the conductance of all layers.
866885
@@ -871,11 +890,16 @@ def get_all_layers_conductance(spot_tuner, fun_control, device="cpu") -> dict:
871890
A dictionary with the function control.
872891
device (str, optional):
873892
The device to use. Defaults to "cpu".
893+
remove_spot_attributes (bool, optional):
894+
Whether to remove the spot attributes. Defaults to False.
874895
"""
875896
config = get_tuned_architecture(spot_tuner, fun_control)
876897
train_model(config, fun_control, timestamp=False)
877898
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
878-
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
899+
if remove_spot_attributes:
900+
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
901+
else:
902+
model = model_loaded
879903
model = model.to(device)
880904
model.eval()
881905
_, index = get_weights(model, return_index=True)

0 commit comments

Comments
 (0)