1414import pandas as pd
1515from captum .attr import LayerConductance , LayerActivation , LayerIntegratedGradients
1616from captum .attr import IntegratedGradients , DeepLift , GradientShap , NoiseTunnel , FeatureAblation
17+ from matplotlib .ticker import MaxNLocator
1718
1819
1920def get_activations (net , fun_control , batch_size , device = "cpu" ) -> dict :
@@ -87,16 +88,21 @@ def get_activations(net, fun_control, batch_size, device="cpu") -> dict:
8788 return activations
8889
8990
90- def get_weights (net ) -> dict :
91+ def get_weights (net , return_index = False ) -> dict :
9192 """
9293 Get the weights of a neural network.
9394
9495 Args:
9596 net (object):
9697 A neural network.
98+ return_index (bool, optional):
99+ Whether to return the index. Defaults to False.
97100
98101 Returns:
99- dict: A dictionary with the weights of the neural network.
102+ dict:
103+ A dictionary with the weights of the neural network.
104+ index (list):
105+ The layer index list.
100106
101107 Examples:
102108 >>> from torch.utils.data import DataLoader
@@ -150,13 +156,19 @@ def get_weights(net) -> dict:
150156
151157 """
152158 weights = {}
159+ index = []
153160 for name , param in net .named_parameters ():
154161 if name .endswith (".bias" ):
155162 continue
163+ # add (int(name.split(".")[1])) to the index list
164+ index .append (int (name .split ("." )[1 ]))
156165 key_name = f"Layer { name .split ('.' )[1 ]} "
157166 weights [key_name ] = param .detach ().view (- 1 ).cpu ().numpy ()
158167 # print(f"weights: {weights}")
159- return weights
168+ if return_index :
169+ return weights , index
170+ else :
171+ return weights
160172
161173
162174def get_gradients (net , fun_control , batch_size , device = "cpu" ) -> dict :
@@ -573,7 +585,6 @@ def get_attributions(
573585 baseline = None ,
574586 abs_attr = True ,
575587 n_rel = 5 ,
576- feature_names = None ,
577588):
578589 """Get the attributions of a neural network.
579590
@@ -590,12 +601,15 @@ def get_attributions(
590601 Whether the method should sort by the absolute attribution values. Defaults to True.
591602 n_rel (int, optional):
592603 The number of relevant features. Defaults to 5.
593- feature_names (list, optional):
594- The feature names. Defaults to None.
595604
596605 Returns:
597606 pd.DataFrame: A DataFrame with the attributions.
598607 """
608+ try :
609+ fun_control ["data_set" ].names
610+ except AttributeError :
611+ fun_control ["data_set" ].names = None
612+ feature_names = fun_control ["data_set" ].names
599613 total_attributions = None
600614 config = get_tuned_architecture (spot_tuner , fun_control )
601615 train_model (config , fun_control , timestamp = False )
@@ -698,3 +712,114 @@ def is_square(n):
698712 False
699713 """
700714 return n == int (math .sqrt (n )) ** 2
715+
716+
717+ def get_layer_conductance (spot_tuner , fun_control , layer_idx ):
718+ """
719+ Compute the average layer conductance attributions for a specified layer in the model.
720+
721+ Args:
722+ spot_tuner (spot.Spot):
723+ The spot tuner object containing the trained model.
724+ fun_control (dict):
725+ The fun_control dictionary containing the hyperparameters used to train the model.
726+ layer_idx (int):
727+ Index of the layer for which to compute layer conductance attributions.
728+
729+ Returns:
730+ numpy.ndarray:
731+ An array containing the average layer conductance attributions for the specified layer.
732+ The shape of the array corresponds to the shape of the attributions.
733+ """
734+ try :
735+ fun_control ["data_set" ].names
736+ except AttributeError :
737+ fun_control ["data_set" ].names = None
738+ feature_names = fun_control ["data_set" ].names
739+
740+ config = get_tuned_architecture (spot_tuner , fun_control )
741+ train_model (config , fun_control , timestamp = False )
742+ model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
743+ removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
744+ model = model .to ("cpu" )
745+ model .eval ()
746+
747+ dataset = fun_control ["data_set" ]
748+ n_features = dataset .data .shape [1 ]
749+ if feature_names is None :
750+ feature_names = [f"x{ i } " for i in range (n_features )]
751+ batch_size = config ["batch_size" ]
752+ # train_loader = DataLoader(dataset, batch_size=batch_size)
753+ test_loader = DataLoader (dataset , batch_size = batch_size )
754+
755+ total_layer_attributions = None
756+ layers = model .layers
757+ print ("Conductance analysis for layer: " , layers [layer_idx ])
758+ lc = LayerConductance (model , layers [layer_idx ])
759+
760+ for inputs , labels in test_loader :
761+ lc_attr_test = lc .attribute (inputs , n_steps = 10 , attribute_to_layer_input = True )
762+ if total_layer_attributions is None :
763+ total_layer_attributions = lc_attr_test
764+ else :
765+ if len (lc_attr_test ) == len (total_layer_attributions ):
766+ total_layer_attributions += lc_attr_test
767+
768+ avg_layer_attributions = total_layer_attributions .mean (dim = 0 ).detach ().numpy ()
769+
770+ return avg_layer_attributions
771+
772+
773+ def get_weights_conductance_last_layer (spot_tuner , fun_control ):
774+ """
775+ Get the weights and the conductance of the last layer.
776+ """
777+ config = get_tuned_architecture (spot_tuner , fun_control )
778+ train_model (config , fun_control , timestamp = False )
779+ model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
780+ removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
781+ model = model .to ("cpu" )
782+ model .eval ()
783+
784+ weights , index = get_weights (model , return_index = True )
785+ layer_idx = index [- 1 ]
786+ weights_last = weights [f"Layer { layer_idx } " ]
787+ weights_last
788+ layer_conductance_last = get_layer_conductance (spot_tuner , fun_control , layer_idx = layer_idx )
789+
790+ return weights_last , layer_conductance_last
791+
792+
793+ def plot_conductance_last_layer (weights_last , layer_conductance_last , show = True ):
794+ """
795+ Plot the conductance of the last layer.
796+ """
797+ fig , ax = plt .subplots (figsize = (12 , 6 ))
798+ ax .bar (range (len (weights_last )), weights_last / weights_last .max (), label = "Weights" , alpha = 0.5 )
799+ ax .bar (
800+ range (len (layer_conductance_last )),
801+ layer_conductance_last / layer_conductance_last .max (),
802+ label = "Layer Conductance" ,
803+ alpha = 0.5 ,
804+ )
805+ ax .set_xlabel ("Weight Index" )
806+ ax .set_ylabel ("Normalized Value" )
807+ ax .set_title ("Layer Conductance vs. Weights" )
808+ ax .legend ()
809+ ax .xaxis .set_major_locator (MaxNLocator (integer = True ))
810+ if show :
811+ plt .show ()
812+
813+
814+ def get_all_layers_conductance (spot_tuner , fun_control ):
815+ config = get_tuned_architecture (spot_tuner , fun_control )
816+ train_model (config , fun_control , timestamp = False )
817+ model_loaded = load_light_from_checkpoint (config , fun_control , postfix = "_TRAIN" )
818+ removed_attributes , model = get_removed_attributes_and_base_net (net = model_loaded )
819+ model = model .to ("cpu" )
820+ model .eval ()
821+ _ , index = get_weights (model , return_index = True )
822+ layer_conductance = {}
823+ for i in index :
824+ layer_conductance [i ] = get_layer_conductance (spot_tuner , fun_control , layer_idx = i )
825+ return layer_conductance
0 commit comments