@@ -6798,7 +6798,33 @@ <h2 id="spotpython.light.trainmodel.train_model_xai" class="doc doc-heading">
67986798< span class ="normal "> 751</ span >
67996799< span class ="normal "> 752</ span >
68006800< span class ="normal "> 753</ span >
6801- < span class ="normal "> 754</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> def</ span > < span class ="nf "> train_model_xai</ span > < span class ="p "> (</ span > < span class ="n "> config</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> fun_control</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> timestamp</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> float</ span > < span class ="p "> :</ span >
6801+ < span class ="normal "> 754</ span >
6802+ < span class ="normal "> 755</ span >
6803+ < span class ="normal "> 756</ span >
6804+ < span class ="normal "> 757</ span >
6805+ < span class ="normal "> 758</ span >
6806+ < span class ="normal "> 759</ span >
6807+ < span class ="normal "> 760</ span >
6808+ < span class ="normal "> 761</ span >
6809+ < span class ="normal "> 762</ span >
6810+ < span class ="normal "> 763</ span >
6811+ < span class ="normal "> 764</ span >
6812+ < span class ="normal "> 765</ span >
6813+ < span class ="normal "> 766</ span >
6814+ < span class ="normal "> 767</ span >
6815+ < span class ="normal "> 768</ span >
6816+ < span class ="normal "> 769</ span >
6817+ < span class ="normal "> 770</ span >
6818+ < span class ="normal "> 771</ span >
6819+ < span class ="normal "> 772</ span >
6820+ < span class ="normal "> 773</ span >
6821+ < span class ="normal "> 774</ span >
6822+ < span class ="normal "> 775</ span >
6823+ < span class ="normal "> 776</ span >
6824+ < span class ="normal "> 777</ span >
6825+ < span class ="normal "> 778</ span >
6826+ < span class ="normal "> 779</ span >
6827+ < span class ="normal "> 780</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> def</ span > < span class ="nf "> train_model_xai</ span > < span class ="p "> (</ span > < span class ="n "> config</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> fun_control</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> timestamp</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> float</ span > < span class ="p "> :</ span >
68026828< span class ="w "> </ span > < span class ="sd "> """</ span >
68036829< span class ="sd "> Trains a model using the given configuration and function control parameters. Performs feature attribution analysis and calculates consistency of these methods.</ span >
68046830
@@ -7131,7 +7157,7 @@ <h2 id="spotpython.light.trainmodel.train_model_xai" class="doc doc-heading">
71317157 < span class ="n "> attributions_list</ span > < span class ="o "> =</ span > < span class ="p "> [</ span > < span class ="n "> attributions_dict</ span > < span class ="p "> [</ span > < span class ="n "> method</ span > < span class ="p "> ]</ span > < span class ="k "> for</ span > < span class ="n "> method</ span > < span class ="ow "> in</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "xai_methods"</ span > < span class ="p "> ]]</ span >
71327158 < span class ="n "> attributions</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> stack</ span > < span class ="p "> (</ span > < span class ="n "> attributions_list</ span > < span class ="p "> ,</ span > < span class ="n "> axis</ span > < span class ="o "> =</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span >
71337159
7134- < span class ="k "> if</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "xai_metric"</ span > < span class ="p "> ]</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="p "> {</ span > < span class ="s2 "> "max_diff"</ span > < span class ="p "> ,</ span > < span class ="s2 "> "variance"</ span > < span class ="p "> ,</ span > < span class ="s2 "> "spearman"</ span > < span class ="p "> }:</ span >
7160+ < span class ="k "> if</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "xai_metric"</ span > < span class ="p "> ]</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="p "> {</ span > < span class ="s2 "> "max_diff"</ span > < span class ="p "> ,</ span > < span class ="s2 "> "variance"</ span > < span class ="p "> ,</ span > < span class ="s2 "> "spearman"</ span > < span class ="p "> , </ span > < span class =" s2 " > "spearman+variance" </ span > < span class =" p " > }:</ span >
71357161 < span class ="nb "> print</ span > < span class ="p "> (</ span > < span class ="s2 "> "Invalid or missing xai_metric. Setting it to 'max_diff'."</ span > < span class ="p "> )</ span >
71367162 < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "xai_metric"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="s2 "> "max_diff"</ span >
71377163
@@ -7166,6 +7192,32 @@ <h2 id="spotpython.light.trainmodel.train_model_xai" class="doc doc-heading">
71667192 < span class ="nb "> print</ span > < span class ="p "> (</ span > < span class ="s2 "> "Spearman rank correlation matrix:</ span > < span class ="se "> \n</ span > < span class ="s2 "> "</ span > < span class ="p "> ,</ span > < span class ="n "> spearman_matrix</ span > < span class ="p "> )</ span >
71677193 < span class ="nb "> print</ span > < span class ="p "> (</ span > < span class ="s2 "> "Consistency Score (Mean Spearman Correlation):"</ span > < span class ="p "> ,</ span > < span class ="o "> -</ span > < span class ="n "> result_xai</ span > < span class ="p "> )</ span >
71687194
7195+ < span class ="k "> if</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "xai_metric"</ span > < span class ="p "> ]</ span > < span class ="o "> ==</ span > < span class ="s2 "> "spearman+variance"</ span > < span class ="p "> :</ span >
7196+ < span class ="c1 "> # Compute Spearman mean</ span >
7197+ < span class ="n "> num_methods</ span > < span class ="o "> =</ span > < span class ="n "> attributions</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 1</ span > < span class ="p "> ]</ span >
7198+ < span class ="n "> spearman_matrix</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> zeros</ span > < span class ="p "> ((</ span > < span class ="n "> num_methods</ span > < span class ="p "> ,</ span > < span class ="n "> num_methods</ span > < span class ="p "> ))</ span >
7199+
7200+ < span class ="k "> for</ span > < span class ="n "> i</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="n "> num_methods</ span > < span class ="p "> ):</ span >
7201+ < span class ="k "> for</ span > < span class ="n "> j</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="n "> i</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> num_methods</ span > < span class ="p "> ):</ span >
7202+ < span class ="n "> corr</ span > < span class ="p "> ,</ span > < span class ="n "> _</ span > < span class ="o "> =</ span > < span class ="n "> spearmanr</ span > < span class ="p "> (</ span > < span class ="n "> attributions</ span > < span class ="p "> [:,</ span > < span class ="n "> i</ span > < span class ="p "> ],</ span > < span class ="n "> attributions</ span > < span class ="p "> [:,</ span > < span class ="n "> j</ span > < span class ="p "> ])</ span >
7203+ < span class ="n "> spearman_matrix</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="p "> ,</ span > < span class ="n "> j</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> corr</ span >
7204+ < span class ="n "> spearman_matrix</ span > < span class ="p "> [</ span > < span class ="n "> j</ span > < span class ="p "> ,</ span > < span class ="n "> i</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> corr</ span >
7205+
7206+ < span class ="n "> upper_triangle_values</ span > < span class ="o "> =</ span > < span class ="n "> spearman_matrix</ span > < span class ="p "> [</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> triu_indices</ span > < span class ="p "> (</ span > < span class ="n "> num_methods</ span > < span class ="p "> ,</ span > < span class ="n "> k</ span > < span class ="o "> =</ span > < span class ="mi "> 1</ span > < span class ="p "> )]</ span >
7207+ < span class ="n "> mean_spearman</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> (</ span > < span class ="n "> upper_triangle_values</ span > < span class ="p "> )</ span >
7208+
7209+ < span class ="c1 "> # Compute attribution variance across methods for each feature</ span >
7210+ < span class ="n "> variance</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> var</ span > < span class ="p "> (</ span > < span class ="n "> attributions</ span > < span class ="p "> ,</ span > < span class ="n "> axis</ span > < span class ="o "> =</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> mean</ span > < span class ="p "> ()</ span > < span class ="c1 "> # mean over features</ span >
7211+
7212+ < span class ="c1 "> # Combine both (λ is a trade-off hyperparameter you define)</ span >
7213+ < span class ="n "> lambda_variance</ span > < span class ="o "> =</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "lambda_variance"</ span > < span class ="p "> ]</ span > < span class ="k "> if</ span > < span class ="s2 "> "lambda_variance"</ span > < span class ="ow "> in</ span > < span class ="n "> fun_control</ span > < span class ="k "> else</ span > < span class ="mf "> 1.0</ span >
7214+ < span class ="n "> result_xai</ span > < span class ="o "> =</ span > < span class ="o "> -</ span > < span class ="n "> mean_spearman</ span > < span class ="o "> +</ span > < span class ="n "> lambda_variance</ span > < span class ="o "> *</ span > < span class ="n "> variance</ span >
7215+
7216+ < span class ="nb "> print</ span > < span class ="p "> (</ span > < span class ="s2 "> "Mean Spearman correlation:"</ span > < span class ="p "> ,</ span > < span class ="n "> mean_spearman</ span > < span class ="p "> )</ span >
7217+ < span class ="nb "> print</ span > < span class ="p "> (</ span > < span class ="s2 "> "Mean Variance:"</ span > < span class ="p "> ,</ span > < span class ="n "> variance</ span > < span class ="p "> )</ span >
7218+ < span class ="nb "> print</ span > < span class ="p "> (</ span > < span class ="s2 "> "Variance Weight:"</ span > < span class ="p "> ,</ span > < span class ="n "> lambda_variance</ span > < span class ="p "> )</ span >
7219+ < span class ="nb "> print</ span > < span class ="p "> (</ span > < span class ="s2 "> "Combined XAI loss: "</ span > < span class ="p "> ,</ span > < span class ="n "> result_xai</ span > < span class ="p "> )</ span >
7220+
71697221 < span class ="c1 "> # -------------------------------------------------------------------------------------------------------------------</ span >
71707222
71717223 < span class ="k "> return</ span > < span class ="n "> result</ span > < span class ="p "> [</ span > < span class ="s2 "> "val_loss"</ span > < span class ="p "> ],</ span > < span class ="n "> result_xai</ span >
0 commit comments