|
7 | 7 | import torch.nn as nn |
8 | 8 | import torch.nn.functional as F |
9 | 9 | import matplotlib.colors as colors |
| 10 | +from spotPython.hyperparameters.values import get_tuned_architecture |
| 11 | +from spotPython.light.trainmodel import train_model |
| 12 | +from spotPython.light.loadmodel import load_light_from_checkpoint |
| 13 | +from spotPython.utils.classes import get_removed_attributes_and_base_net |
| 14 | +import pandas as pd |
| 15 | +from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients |
| 16 | +from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation |
10 | 17 |
|
11 | 18 |
|
12 | 19 | def get_activations(net, fun_control, batch_size, device="cpu") -> dict: |
@@ -472,3 +479,119 @@ def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray" |
472 | 479 | batch_size=batch_size, |
473 | 480 | ) |
474 | 481 | plot_nn_values_scatter(nn_values=grads, nn_values_names="Gradients", absolute=absolute, cmap=cmap, figsize=figsize) |
| 482 | + |
| 483 | + |
| 484 | +def get_attributions( |
| 485 | + spot_tuner, |
| 486 | + fun_control, |
| 487 | + attr_method="IntegratedGradients", |
| 488 | + baseline=None, |
| 489 | + abs_attr=True, |
| 490 | + n_rel=5, |
| 491 | + feature_names=None, |
| 492 | +): |
| 493 | + """Get the attributions of a neural network. |
| 494 | +
|
| 495 | + Args: |
| 496 | + spot_tuner (object): |
| 497 | + The spot tuner object. |
| 498 | + fun_control (dict): |
| 499 | + A dictionary with the function control. |
| 500 | + attr_method (str, optional): |
| 501 | + The attribution method. Defaults to "IntegratedGradients". |
| 502 | + baseline (torch.Tensor, optional): |
| 503 | + The baseline for the attribution methods. Defaults to None. |
| 504 | + abs_attr (bool, optional): |
| 505 | + Whether the method should sort by the absolute attribution values. Defaults to True. |
| 506 | + n_rel (int, optional): |
| 507 | + The number of relevant features. Defaults to 5. |
| 508 | + feature_names (list, optional): |
| 509 | + The feature names. Defaults to None. |
| 510 | +
|
| 511 | + Returns: |
| 512 | + pd.DataFrame: A DataFrame with the attributions. |
| 513 | + """ |
| 514 | + total_attributions = None |
| 515 | + config = get_tuned_architecture(spot_tuner, fun_control) |
| 516 | + train_model(config, fun_control, timestamp=False) |
| 517 | + model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN") |
| 518 | + removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded) |
| 519 | + model = model.to("cpu") |
| 520 | + model.eval() |
| 521 | + dataset = fun_control["data_set"] |
| 522 | + n_features = dataset.data.shape[1] |
| 523 | + if feature_names is None: |
| 524 | + feature_names = [f"x{i}" for i in range(n_features)] |
| 525 | + batch_size = config["batch_size"] |
| 526 | + # train_loader = DataLoader(dataset, batch_size=batch_size) |
| 527 | + test_loader = DataLoader(dataset, batch_size=batch_size) |
| 528 | + if attr_method == "IntegratedGradients": |
| 529 | + attr = IntegratedGradients(model) |
| 530 | + elif attr_method == "DeepLift": |
| 531 | + attr = DeepLift(model) |
| 532 | + elif attr_method == "GradientShap": # Todo: would need a baseline |
| 533 | + if baseline is None: |
| 534 | + raise ValueError("baseline cannot be 'None' for GradientShap") |
| 535 | + attr = GradientShap(model) |
| 536 | + elif attr_method == "FeatureAblation": |
| 537 | + attr = FeatureAblation(model) |
| 538 | + else: |
| 539 | + raise ValueError( |
| 540 | + """ |
| 541 | + Unsupported attribution method. |
| 542 | + Please choose from 'IntegratedGradients', 'DeepLift', 'GradientShap', or 'FeatureAblation'. |
| 543 | + """ |
| 544 | + ) |
| 545 | + for inputs, labels in test_loader: |
| 546 | + attributions = attr.attribute(inputs, return_convergence_delta=False, baselines=baseline) |
| 547 | + if total_attributions is None: |
| 548 | + total_attributions = attributions |
| 549 | + else: |
| 550 | + if len(attributions) == len(total_attributions): |
| 551 | + total_attributions += attributions |
| 552 | + |
| 553 | + # Calculation of average attribution across all batches |
| 554 | + avg_attributions = total_attributions.mean(dim=0).detach().numpy() |
| 555 | + |
| 556 | + # Transformation to the absolute attribution values if abs_attr is True |
| 557 | + # Get indices of the n most important features |
| 558 | + if abs_attr is True: |
| 559 | + abs_avg_attributions = abs(avg_attributions) |
| 560 | + top_n_indices = abs_avg_attributions.argsort()[-n_rel:][::-1] |
| 561 | + else: |
| 562 | + top_n_indices = avg_attributions.argsort()[-n_rel:][::-1] |
| 563 | + |
| 564 | + # Get the importance values for the top n features |
| 565 | + top_n_importances = avg_attributions[top_n_indices] |
| 566 | + |
| 567 | + df = pd.DataFrame( |
| 568 | + { |
| 569 | + "Feature Index": top_n_indices, |
| 570 | + "Feature": [feature_names[i] for i in top_n_indices], |
| 571 | + attr_method + "Attribution": top_n_importances, |
| 572 | + } |
| 573 | + ) |
| 574 | + return df |
| 575 | + |
| 576 | + |
| 577 | +def plot_attributions(df, attr_method="IntegratedGradients"): |
| 578 | + """ |
| 579 | + Plot the attributions of a neural network. |
| 580 | +
|
| 581 | + Args: |
| 582 | + df (pd.DataFrame): |
| 583 | + A DataFrame with the attributions. |
| 584 | + attr_method (str, optional): |
| 585 | + The attribution method. Defaults to "IntegratedGradients". |
| 586 | +
|
| 587 | + Returns: |
| 588 | + None |
| 589 | +
|
| 590 | + """ |
| 591 | + sns.set_theme(style="whitegrid") |
| 592 | + plt.figure(figsize=(10, 6)) |
| 593 | + sns.barplot(x=attr_method + "Attribution", y="Feature", data=df, palette="viridis", hue="Feature") |
| 594 | + plt.title(f"Top {df.shape[0]} Features by {attr_method} Attribution") |
| 595 | + plt.xlabel(f"{attr_method} Attribution Value") |
| 596 | + plt.ylabel("Feature") |
| 597 | + plt.show() |
0 commit comments