Skip to content

Commit 02a0e50

Browse files
Deployed 4c681a1 with MkDocs version: 1.6.0
1 parent e76ab81 commit 02a0e50

6 files changed

Lines changed: 3216 additions & 2514 deletions

File tree

objects.inv

17 Bytes
Binary file not shown.

reference/spotpython/light/trainmodel/index.html

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6798,7 +6798,33 @@ <h2 id="spotpython.light.trainmodel.train_model_xai" class="doc doc-heading">
67986798
<span class="normal">751</span>
67996799
<span class="normal">752</span>
68006800
<span class="normal">753</span>
6801-
<span class="normal">754</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">train_model_xai</span><span class="p">(</span><span class="n">config</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">fun_control</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">timestamp</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
6801+
<span class="normal">754</span>
6802+
<span class="normal">755</span>
6803+
<span class="normal">756</span>
6804+
<span class="normal">757</span>
6805+
<span class="normal">758</span>
6806+
<span class="normal">759</span>
6807+
<span class="normal">760</span>
6808+
<span class="normal">761</span>
6809+
<span class="normal">762</span>
6810+
<span class="normal">763</span>
6811+
<span class="normal">764</span>
6812+
<span class="normal">765</span>
6813+
<span class="normal">766</span>
6814+
<span class="normal">767</span>
6815+
<span class="normal">768</span>
6816+
<span class="normal">769</span>
6817+
<span class="normal">770</span>
6818+
<span class="normal">771</span>
6819+
<span class="normal">772</span>
6820+
<span class="normal">773</span>
6821+
<span class="normal">774</span>
6822+
<span class="normal">775</span>
6823+
<span class="normal">776</span>
6824+
<span class="normal">777</span>
6825+
<span class="normal">778</span>
6826+
<span class="normal">779</span>
6827+
<span class="normal">780</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">train_model_xai</span><span class="p">(</span><span class="n">config</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">fun_control</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">timestamp</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
68026828
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
68036829
<span class="sd"> Trains a model using the given configuration and function control parameters. Performs feature attribution analysis and calculates consistency of these methods.</span>
68046830

@@ -7131,7 +7157,7 @@ <h2 id="spotpython.light.trainmodel.train_model_xai" class="doc doc-heading">
71317157
<span class="n">attributions_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">attributions_dict</span><span class="p">[</span><span class="n">method</span><span class="p">]</span> <span class="k">for</span> <span class="n">method</span> <span class="ow">in</span> <span class="n">fun_control</span><span class="p">[</span><span class="s2">&quot;xai_methods&quot;</span><span class="p">]]</span>
71327158
<span class="n">attributions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">attributions_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
71337159

7134-
<span class="k">if</span> <span class="n">fun_control</span><span class="p">[</span><span class="s2">&quot;xai_metric&quot;</span><span class="p">]</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">{</span><span class="s2">&quot;max_diff&quot;</span><span class="p">,</span> <span class="s2">&quot;variance&quot;</span><span class="p">,</span> <span class="s2">&quot;spearman&quot;</span><span class="p">}:</span>
7160+
<span class="k">if</span> <span class="n">fun_control</span><span class="p">[</span><span class="s2">&quot;xai_metric&quot;</span><span class="p">]</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">{</span><span class="s2">&quot;max_diff&quot;</span><span class="p">,</span> <span class="s2">&quot;variance&quot;</span><span class="p">,</span> <span class="s2">&quot;spearman&quot;</span><span class="p">,</span> <span class="s2">&quot;spearman+variance&quot;</span><span class="p">}:</span>
71357161
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Invalid or missing xai_metric. Setting it to &#39;max_diff&#39;.&quot;</span><span class="p">)</span>
71367162
<span class="n">fun_control</span><span class="p">[</span><span class="s2">&quot;xai_metric&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;max_diff&quot;</span>
71377163

@@ -7166,6 +7192,32 @@ <h2 id="spotpython.light.trainmodel.train_model_xai" class="doc doc-heading">
71667192
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Spearman rank correlation matrix:</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">spearman_matrix</span><span class="p">)</span>
71677193
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Consistency Score (Mean Spearman Correlation):&quot;</span><span class="p">,</span> <span class="o">-</span><span class="n">result_xai</span><span class="p">)</span>
71687194

7195+
<span class="k">if</span> <span class="n">fun_control</span><span class="p">[</span><span class="s2">&quot;xai_metric&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;spearman+variance&quot;</span><span class="p">:</span>
7196+
<span class="c1"># Compute Spearman mean</span>
7197+
<span class="n">num_methods</span> <span class="o">=</span> <span class="n">attributions</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
7198+
<span class="n">spearman_matrix</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_methods</span><span class="p">,</span> <span class="n">num_methods</span><span class="p">))</span>
7199+
7200+
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_methods</span><span class="p">):</span>
7201+
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_methods</span><span class="p">):</span>
7202+
<span class="n">corr</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">spearmanr</span><span class="p">(</span><span class="n">attributions</span><span class="p">[:,</span> <span class="n">i</span><span class="p">],</span> <span class="n">attributions</span><span class="p">[:,</span> <span class="n">j</span><span class="p">])</span>
7203+
<span class="n">spearman_matrix</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">corr</span>
7204+
<span class="n">spearman_matrix</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">corr</span>
7205+
7206+
<span class="n">upper_triangle_values</span> <span class="o">=</span> <span class="n">spearman_matrix</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">triu_indices</span><span class="p">(</span><span class="n">num_methods</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
7207+
<span class="n">mean_spearman</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">upper_triangle_values</span><span class="p">)</span>
7208+
7209+
<span class="c1"># Compute attribution variance across methods for each feature</span>
7210+
<span class="n">variance</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">var</span><span class="p">(</span><span class="n">attributions</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="c1"># mean over features</span>
7211+
7212+
<span class="c1"># Combine both (λ is a trade-off hyperparameter you define)</span>
7213+
<span class="n">lambda_variance</span> <span class="o">=</span> <span class="n">fun_control</span><span class="p">[</span><span class="s2">&quot;lambda_variance&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="s2">&quot;lambda_variance&quot;</span> <span class="ow">in</span> <span class="n">fun_control</span> <span class="k">else</span> <span class="mf">1.0</span>
7214+
<span class="n">result_xai</span> <span class="o">=</span> <span class="o">-</span><span class="n">mean_spearman</span> <span class="o">+</span> <span class="n">lambda_variance</span> <span class="o">*</span> <span class="n">variance</span>
7215+
7216+
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Mean Spearman correlation:&quot;</span><span class="p">,</span> <span class="n">mean_spearman</span><span class="p">)</span>
7217+
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Mean Variance:&quot;</span><span class="p">,</span> <span class="n">variance</span><span class="p">)</span>
7218+
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Variance Weight:&quot;</span><span class="p">,</span> <span class="n">lambda_variance</span><span class="p">)</span>
7219+
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Combined XAI loss: &quot;</span><span class="p">,</span> <span class="n">result_xai</span><span class="p">)</span>
7220+
71697221
<span class="c1"># -------------------------------------------------------------------------------------------------------------------</span>
71707222

71717223
<span class="k">return</span> <span class="n">result</span><span class="p">[</span><span class="s2">&quot;val_loss&quot;</span><span class="p">],</span> <span class="n">result_xai</span>

0 commit comments

Comments
 (0)