@@ -580,6 +580,7 @@ def get_attributions(
580580 n_rel = 5 ,
581581 device = "cpu" ,
582582 normalize = True ,
583+ remove_spot_attributes = False ,
583584) -> pd .DataFrame :
584585 """Get the attributions of a neural network.
585586
@@ -600,6 +601,9 @@ def get_attributions(
600601 The device to use. Defaults to "cpu".
601602 normalize (bool, optional):
602603 Whether to normalize the input data. Defaults to True.
604+ remove_spot_attributes (bool, optional):
605+ Whether to remove the spot attributes.
606+ If True, a torch model is created via `get_removed_attributes`. Defaults to False.
603607
604608 Returns:
605609 pd.DataFrame (object): A DataFrame with the attributions.
@@ -613,7 +617,10 @@ def get_attributions(
613617 config = get_tuned_architecture (spot_tuner , fun_control )
614618 train_model (config , fun_control , timestamp = False )
615619 model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
616- removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
620+ if remove_spot_attributes :
621+ removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
622+ else :
623+ model = model_loaded
617624 model = model .to (device )
618625 model .eval ()
619626 # get feature names
@@ -731,7 +738,9 @@ def is_square(n) -> bool:
731738 return n == int (math .sqrt (n )) ** 2
732739
733740
734- def get_layer_conductance (spot_tuner , fun_control , layer_idx , device = "cpu" , normalize = True ) -> np .ndarray :
741+ def get_layer_conductance (
742+ spot_tuner , fun_control , layer_idx , device = "cpu" , normalize = True , remove_spot_attributes = False
743+ ) -> np .ndarray :
735744 """
736745 Compute the average layer conductance attributions for a specified layer in the model.
737746
@@ -746,6 +755,8 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu", norm
746755 The device to use. Defaults to "cpu".
747756 normalize (bool, optional):
748757 Whether to normalize the input data. Defaults to True.
758+ remove_spot_attributes (bool, optional):
759+ Whether to remove the spot attributes. Defaults to False.
749760
750761 Returns:
751762 numpy.ndarray:
@@ -761,7 +772,10 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu", norm
761772 config = get_tuned_architecture (spot_tuner , fun_control )
762773 train_model (config , fun_control , timestamp = False )
763774 model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
764- removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
775+ if remove_spot_attributes :
776+ removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
777+ else :
778+ model = model_loaded
765779 model = model .to (device )
766780 model .eval ()
767781
@@ -794,7 +808,7 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu", norm
794808 return avg_layer_attributions
795809
796810
797- def get_weights_conductance_last_layer (spot_tuner , fun_control , device = "cpu" ) -> tuple :
811+ def get_weights_conductance_last_layer (spot_tuner , fun_control , device = "cpu" , remove_spot_attributes = False ) -> tuple :
798812 """
799813 Get the weights and the conductance of the last layer.
800814
@@ -805,11 +819,16 @@ def get_weights_conductance_last_layer(spot_tuner, fun_control, device="cpu") ->
805819 A dictionary with the function control.
806820 device (str, optional):
807821 The device to use. Defaults to "cpu".
822+ remove_spot_attributes (bool, optional):
823+ Whether to remove the spot attributes. Defaults to False.
808824 """
809825 config = get_tuned_architecture (spot_tuner , fun_control )
810826 train_model (config , fun_control , timestamp = False )
811827 model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
812- removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
828+ if remove_spot_attributes :
829+ removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
830+ else :
831+ model = model_loaded
813832 model = model .to (device )
814833 model .eval ()
815834
@@ -860,7 +879,7 @@ def plot_conductance_last_layer(weights_last, layer_conductance_last, figsize=(1
860879 plt .show ()
861880
862881
863- def get_all_layers_conductance (spot_tuner , fun_control , device = "cpu" ) -> dict :
882+ def get_all_layers_conductance (spot_tuner , fun_control , device = "cpu" , remove_spot_attributes = False ) -> dict :
864883 """
865884 Get the conductance of all layers.
866885
@@ -871,11 +890,16 @@ def get_all_layers_conductance(spot_tuner, fun_control, device="cpu") -> dict:
871890 A dictionary with the function control.
872891 device (str, optional):
873892 The device to use. Defaults to "cpu".
893+ remove_spot_attributes (bool, optional):
894+ Whether to remove the spot attributes. Defaults to False.
874895 """
875896 config = get_tuned_architecture (spot_tuner , fun_control )
876897 train_model (config , fun_control , timestamp = False )
877898 model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
878- removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
899+ if remove_spot_attributes :
900+ removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
901+ else :
902+ model = model_loaded
879903 model = model .to (device )
880904 model .eval ()
881905 _ , index = get_weights (model , return_index = True )
0 commit comments