1515from captum .attr import LayerConductance , LayerActivation , LayerIntegratedGradients
1616from captum .attr import IntegratedGradients , DeepLift , GradientShap , NoiseTunnel , FeatureAblation
1717from matplotlib .ticker import MaxNLocator
18+ from spotpython .data .lightdatamodule import LightDataModule
1819
1920
2021def get_activations (net , fun_control , batch_size , device = "cpu" ) -> dict :
@@ -72,7 +73,7 @@ def get_activations(net, fun_control, batch_size, device="cpu") -> dict:
7273 """
7374 activations = {}
7475 net .eval ()
75- print (f"net: { net } " )
76+ # print(f"net: {net}")
7677 dataset = fun_control ["data_set" ]
7778 dataloader = DataLoader (dataset , batch_size = batch_size , shuffle = False )
7879 inputs , _ = next (iter (dataloader ))
@@ -549,7 +550,7 @@ def visualize_weights(net, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
549550 plot_nn_values_scatter (nn_values = weights , nn_values_names = "Weights" , absolute = absolute , cmap = cmap , figsize = figsize )
550551
551552
552- def visualize_gradients (net , fun_control , batch_size , absolute = True , cmap = "gray" , figsize = (6 , 6 )) -> None :
553+ def visualize_gradients (net , fun_control , batch_size , absolute = True , cmap = "gray" , figsize = (6 , 6 ), device = "cpu" ) -> None :
553554 """
554555 Scatter plots the gradients of a neural network.
555556
@@ -566,6 +567,8 @@ def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray"
566567 The colormap to use. Defaults to "gray".
567568 figsize (tuple, optional):
568569 The figure size. Defaults to (6, 6).
570+ device (str, optional):
571+ The device to use. Defaults to "cpu".
569572
570573 Returns:
571574 None
@@ -574,6 +577,7 @@ def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray"
574577 net ,
575578 fun_control ,
576579 batch_size = batch_size ,
580+ device = device ,
577581 )
578582 plot_nn_values_scatter (nn_values = grads , nn_values_names = "Gradients" , absolute = absolute , cmap = cmap , figsize = figsize )
579583
@@ -585,6 +589,7 @@ def get_attributions(
585589 baseline = None ,
586590 abs_attr = True ,
587591 n_rel = 5 ,
592+ device = "cpu" ,
588593) -> pd .DataFrame :
589594 """Get the attributions of a neural network.
590595
@@ -601,6 +606,8 @@ def get_attributions(
601606 Whether the method should sort by the absolute attribution values. Defaults to True.
602607 n_rel (int, optional):
603608 The number of relevant features. Defaults to 5.
609+ device (str, optional):
610+ The device to use. Defaults to "cpu".
604611
605612 Returns:
606613 pd.DataFrame (object): A DataFrame with the attributions.
@@ -615,7 +622,121 @@ def get_attributions(
615622 train_model (config , fun_control , timestamp = False )
616623 model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
617624 removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
618- model = model .to ("cpu" )
625+ model = model .to (device )
626+ model .eval ()
627+ # get feature names
628+ dataset = fun_control ["data_set" ]
629+ try :
630+ n_features = dataset .data .shape [1 ]
631+ except AttributeError :
632+ n_features = dataset .tensors [0 ].shape [1 ]
633+ if feature_names is None :
634+ feature_names = [f"x{ i } " for i in range (n_features )]
635+ # get batch size
636+ batch_size = config ["batch_size" ]
637+ # test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
638+
639+ data_module = LightDataModule (
640+ dataset = dataset ,
641+ batch_size = batch_size ,
642+ test_size = fun_control ["test_size" ],
643+ scaler = fun_control ["scaler" ],
644+ verbosity = 10
645+ )
646+ data_module .setup (stage = "test" )
647+ test_loader = data_module .test_dataloader ()
648+
649+ if attr_method == "IntegratedGradients" :
650+ attr = IntegratedGradients (model )
651+ elif attr_method == "DeepLift" :
652+ attr = DeepLift (model )
653+ elif attr_method == "GradientShap" : # Todo: would need a baseline
654+ if baseline is None :
655+ raise ValueError ("baseline cannot be 'None' for GradientShap" )
656+ attr = GradientShap (model )
657+ elif attr_method == "FeatureAblation" :
658+ attr = FeatureAblation (model )
659+ else :
660+ raise ValueError (
661+ """
662+ Unsupported attribution method.
663+ Please choose from 'IntegratedGradients', 'DeepLift', 'GradientShap', or 'FeatureAblation'.
664+ """
665+ )
666+ for inputs , _ in test_loader :
667+ inputs .requires_grad_ ()
668+ attributions = attr .attribute (inputs , return_convergence_delta = False , baselines = baseline )
669+ if total_attributions is None :
670+ total_attributions = attributions
671+ else :
672+ if len (attributions ) == len (total_attributions ):
673+ total_attributions += attributions
674+
675+ # Calculation of average attribution across all batches
676+ avg_attributions = total_attributions .mean (dim = 0 ).detach ().numpy ()
677+
678+ # Transformation to the absolute attribution values if abs_attr is True
679+ # Get indices of the n most important features
680+ if abs_attr is True :
681+ abs_avg_attributions = abs (avg_attributions )
682+ top_n_indices = abs_avg_attributions .argsort ()[- n_rel :][::- 1 ]
683+ else :
684+ top_n_indices = avg_attributions .argsort ()[- n_rel :][::- 1 ]
685+
686+ # Get the importance values for the top n features
687+ top_n_importances = avg_attributions [top_n_indices ]
688+
689+ df = pd .DataFrame (
690+ {
691+ "Feature Index" : top_n_indices ,
692+ "Feature" : [feature_names [i ] for i in top_n_indices ],
693+ attr_method + "Attribution" : top_n_importances ,
694+ }
695+ )
696+ return df
697+
698+
699+ def get_attributions_old (
700+ spot_tuner ,
701+ fun_control ,
702+ attr_method = "IntegratedGradients" ,
703+ baseline = None ,
704+ abs_attr = True ,
705+ n_rel = 5 ,
706+ device = "cpu" ,
707+ ) -> pd .DataFrame :
708+ """Get the attributions of a neural network.
709+
710+ Args:
711+ spot_tuner (object):
712+ The spot tuner object.
713+ fun_control (dict):
714+ A dictionary with the function control.
715+ attr_method (str, optional):
716+ The attribution method. Defaults to "IntegratedGradients".
717+ baseline (torch.Tensor, optional):
718+ The baseline for the attribution methods. Defaults to None.
719+ abs_attr (bool, optional):
720+ Whether the method should sort by the absolute attribution values. Defaults to True.
721+ n_rel (int, optional):
722+ The number of relevant features. Defaults to 5.
723+ device (str, optional):
724+ The device to use. Defaults to "cpu".
725+
726+ Returns:
727+ pd.DataFrame (object): A DataFrame with the attributions.
728+ """
729+ try :
730+ fun_control ["data_set" ].names
731+ except AttributeError :
732+ fun_control ["data_set" ].names = None
733+ feature_names = fun_control ["data_set" ].names
734+ total_attributions = None
735+ config = get_tuned_architecture (spot_tuner , fun_control )
736+ train_model (config , fun_control , timestamp = False )
737+ model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
738+ removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
739+ model = model .to (device )
619740 model .eval ()
620741 dataset = fun_control ["data_set" ]
621742 try :
@@ -626,7 +747,7 @@ def get_attributions(
626747 feature_names = [f"x{ i } " for i in range (n_features )]
627748 batch_size = config ["batch_size" ]
628749 # train_loader = DataLoader(dataset, batch_size=batch_size)
629- test_loader = DataLoader (dataset , batch_size = batch_size )
750+ test_loader = DataLoader (dataset , batch_size = batch_size , shuffle = False )
630751 if attr_method == "IntegratedGradients" :
631752 attr = IntegratedGradients (model )
632753 elif attr_method == "DeepLift" :
@@ -717,7 +838,7 @@ def is_square(n) -> bool:
717838 return n == int (math .sqrt (n )) ** 2
718839
719840
720- def get_layer_conductance (spot_tuner , fun_control , layer_idx ) -> np .ndarray :
841+ def get_layer_conductance (spot_tuner , fun_control , layer_idx , device = "cpu" ) -> np .ndarray :
721842 """
722843 Compute the average layer conductance attributions for a specified layer in the model.
723844
@@ -728,6 +849,8 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx) -> np.ndarray:
728849 The fun_control dictionary containing the hyperparameters used to train the model.
729850 layer_idx (int):
730851 Index of the layer for which to compute layer conductance attributions.
852+ device (str, optional):
853+ The device to use. Defaults to "cpu".
731854
732855 Returns:
733856 numpy.ndarray:
@@ -744,7 +867,7 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx) -> np.ndarray:
744867 train_model (config , fun_control , timestamp = False )
745868 model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
746869 removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
747- model = model .to ("cpu" )
870+ model = model .to (device )
748871 model .eval ()
749872
750873 dataset = fun_control ["data_set" ]
@@ -776,15 +899,23 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx) -> np.ndarray:
776899 return avg_layer_attributions
777900
778901
779- def get_weights_conductance_last_layer (spot_tuner , fun_control ) :
902+ def get_weights_conductance_last_layer (spot_tuner , fun_control , device = "cpu" ) -> tuple :
780903 """
781904 Get the weights and the conductance of the last layer.
905+
906+ Args:
907+ spot_tuner (object):
908+ The spot tuner object.
909+ fun_control (dict):
910+ A dictionary with the function control.
911+ device (str, optional):
912+ The device to use. Defaults to "cpu".
782913 """
783914 config = get_tuned_architecture (spot_tuner , fun_control )
784915 train_model (config , fun_control , timestamp = False )
785916 model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
786917 removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
787- model = model .to ("cpu" )
918+ model = model .to (device )
788919 model .eval ()
789920
790921 weights , index = get_weights (model , return_index = True )
@@ -796,11 +927,28 @@ def get_weights_conductance_last_layer(spot_tuner, fun_control):
796927 return weights_last , layer_conductance_last
797928
798929
799- def plot_conductance_last_layer (weights_last , layer_conductance_last , show = True ):
930+ def plot_conductance_last_layer (weights_last , layer_conductance_last , figsize = ( 12 , 6 ), show = True ) -> None :
800931 """
801932 Plot the conductance of the last layer.
933+
934+ Args:
935+ weights_last (np.ndarray):
936+ The weights of the last layer.
937+ layer_conductance_last (np.ndarray):
938+ The conductance of the last layer.
939+ figsize (tuple, optional):
940+ The figure size. Defaults to (12, 6).
941+ show (bool, optional):
942+ Whether to show the plot. Defaults
943+
944+ Examples:
945+ >>> import numpy as np
946+ from spotpython.plot.xai import plot_conductance_last_layer
947+ weights_last = np.random.rand(10)
948+ layer_conductance_last = np.random.rand(10)
949+ plot_conductance_last_layer(weights_last, layer_conductance_last, show=True)
802950 """
803- fig , ax = plt .subplots (figsize = ( 12 , 6 ) )
951+ fig , ax = plt .subplots (figsize = figsize )
804952 ax .bar (range (len (weights_last )), weights_last / weights_last .max (), label = "Weights" , alpha = 0.5 )
805953 ax .bar (
806954 range (len (layer_conductance_last )),
@@ -817,15 +965,55 @@ def plot_conductance_last_layer(weights_last, layer_conductance_last, show=True)
817965 plt .show ()
818966
819967
820- def get_all_layers_conductance (spot_tuner , fun_control ):
968+ def get_all_layers_conductance (spot_tuner , fun_control , device = "cpu" ) -> dict :
969+ """
970+ Get the conductance of all layers.
971+
972+ Args:
973+ spot_tuner (object):
974+ The spot tuner object.
975+ fun_control (dict):
976+ A dictionary with the function control.
977+ device (str, optional):
978+ The device to use. Defaults to "cpu".
979+ """
821980 config = get_tuned_architecture (spot_tuner , fun_control )
822981 train_model (config , fun_control , timestamp = False )
823982 model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
824983 removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
825- model = model .to ("cpu" )
984+ model = model .to (device )
826985 model .eval ()
827986 _ , index = get_weights (model , return_index = True )
828987 layer_conductance = {}
829988 for i in index :
830989 layer_conductance [i ] = get_layer_conductance (spot_tuner , fun_control , layer_idx = i )
831990 return layer_conductance
991+
992+
993+ def sort_layers (data_dict ) -> dict :
994+ """
995+ Sorts a dictionary with keys in the format "Layer X" based on the numerical value X.
996+
997+ Args:
998+ data_dict (dict): A dictionary with keys in the format "Layer X".
999+
1000+ Returns:
1001+ dict: A dictionary with the keys sorted based on the numerical value X.
1002+
1003+ Examples:
1004+ >>> data_dict = {
1005+ ... "Layer 1": [1, 2, 3],
1006+ ... "Layer 3": [4, 5, 6],
1007+ ... "Layer 2": [7, 8, 9]
1008+ ... }
1009+ >>> sort_layers(data_dict)
1010+ {'Layer 1': [1, 2, 3], 'Layer 2': [7, 8, 9], 'Layer 3': [4,
1011+
1012+ """
1013+ # Use a lambda function to extract the number X from "Layer X" and sort based on that number
1014+ sorted_items = sorted (data_dict .items (), key = lambda item : int (item [0 ].split ()[1 ]))
1015+
1016+ # Create a new dictionary from the sorted items
1017+ sorted_dict = dict (sorted_items )
1018+
1019+ return sorted_dict
0 commit comments