Skip to content

Commit c37afbc

Browse files
committed
nn funnel classifier implementation
1 parent 1d9d3b6 commit c37afbc

3 files changed

Lines changed: 344 additions & 19 deletions

File tree

src/spotpython/hyperdict/light_hyper_dict.json

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,5 +1123,132 @@
11231123
"lower": 0,
11241124
"upper": 3
11251125
}
1126+
},
1127+
"NNFunnelClassifier": {
1128+
"l1": {
1129+
"type": "int",
1130+
"default": 3,
1131+
"transform": "transform_power_2_int",
1132+
"lower": 3,
1133+
"upper": 8
1134+
},
1135+
"num_layers": {
1136+
"type": "int",
1137+
"default": 3,
1138+
"transform": "None",
1139+
"lower": 2,
1140+
"upper": 10
1141+
},
1142+
"epochs": {
1143+
"type": "int",
1144+
"default": 4,
1145+
"transform": "transform_power_2_int",
1146+
"lower": 4,
1147+
"upper": 9
1148+
},
1149+
"batch_size": {
1150+
"type": "int",
1151+
"default": 4,
1152+
"transform": "transform_power_2_int",
1153+
"lower": 1,
1154+
"upper": 4
1155+
},
1156+
"act_fn": {
1157+
"levels": [
1158+
"Tanh",
1159+
"ReLU",
1160+
"LeakyReLU",
1161+
"ELU",
1162+
"Swish"
1163+
],
1164+
"type": "factor",
1165+
"default": "ReLU",
1166+
"transform": "None",
1167+
"class_name": "spotpython.torch.activation",
1168+
"core_model_parameter_type": "instance()",
1169+
"lower": 0,
1170+
"upper": 5
1171+
},
1172+
"optimizer": {
1173+
"levels": [
1174+
"Adadelta",
1175+
"Adagrad",
1176+
"Adam",
1177+
"AdamW",
1178+
"SparseAdam",
1179+
"Adamax",
1180+
"ASGD",
1181+
"NAdam",
1182+
"RAdam",
1183+
"RMSprop",
1184+
"Rprop",
1185+
"SGD"
1186+
],
1187+
"type": "factor",
1188+
"default": "SGD",
1189+
"transform": "None",
1190+
"class_name": "torch.optim",
1191+
"core_model_parameter_type": "str",
1192+
"lower": 0,
1193+
"upper": 11
1194+
},
1195+
"dropout_prob": {
1196+
"type": "float",
1197+
"default": 0.01,
1198+
"transform": "None",
1199+
"lower": 0.0,
1200+
"upper": 0.25
1201+
},
1202+
"lr_mult": {
1203+
"type": "float",
1204+
"default": 1.0,
1205+
"transform": "None",
1206+
"lower": 0.1,
1207+
"upper": 10.0
1208+
},
1209+
"patience": {
1210+
"type": "int",
1211+
"default": 2,
1212+
"transform": "transform_power_2_int",
1213+
"lower": 2,
1214+
"upper": 6
1215+
},
1216+
"initialization": {
1217+
"levels": [
1218+
"Default",
1219+
"Kaiming",
1220+
"Xavier"
1221+
],
1222+
"type": "factor",
1223+
"default": "Default",
1224+
"transform": "None",
1225+
"core_model_parameter_type": "str",
1226+
"lower": 0,
1227+
"upper": 2
1228+
},
1229+
"batch_norm": {
1230+
"levels": [
1231+
0,
1232+
1
1233+
],
1234+
"type": "factor",
1235+
"default": 0,
1236+
"transform": "None",
1237+
"core_model_parameter_type": "bool",
1238+
"lower": 0,
1239+
"upper": 1
1240+
},
1241+
"lr_sched": {
1242+
"levels": [
1243+
0,
1244+
1
1245+
],
1246+
"type": "factor",
1247+
"default": 0,
1248+
"transform": "None",
1249+
"core_model_parameter_type": "bool",
1250+
"lower": 0,
1251+
"upper": 1
1252+
}
11261253
}
11271254
}
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import lightning as L
2+
import torch
3+
import torch.nn.functional as F
4+
from torch import nn
5+
from spotpython.hyperparameters.optimizer import optimizer_handler
6+
import torchmetrics.functional.classification as TMclf
7+
import torch.optim as optim
8+
9+
10+
class NNFunnelClassifier(L.LightningModule):
11+
"""
12+
Funnel-shaped MLP for classification (binary & multiclass).
13+
14+
Attributes:
15+
l1 (int): neurons in first hidden layer.
16+
num_layers (int): number of hidden layers.
17+
epochs (int): number of training epochs (used for LR scheduler milestones).
18+
batch_size (int): batch size (used for example_input_array).
19+
initialization (str): (keine direkte Nutzung hier – identisch zur Vorlage).
20+
act_fn (nn.Module): activation module (keine Ignorierung; bleibt tunebar).
21+
optimizer (str): optimizer name for optimizer_handler.
22+
dropout_prob (float): dropout probability.
23+
lr_mult (float): learning-rate multiplier (passed to optimizer_handler).
24+
patience (int): (nicht in dieser Klasse verwendet – wie Vorlage).
25+
_L_in (int): input dimension.
26+
_L_out (int): number of classes. If 1 => binary, else multiclass.
27+
_torchmetric (str): optional metric name ("accuracy" default). Used for logging, not as loss.
28+
layers (nn.Sequential): the network.
29+
"""
30+
31+
def __init__(
32+
self,
33+
l1: int,
34+
num_layers: int,
35+
epochs: int,
36+
batch_size: int,
37+
initialization: str,
38+
act_fn: nn.Module,
39+
optimizer: str,
40+
dropout_prob: float,
41+
lr_mult: float,
42+
patience: int,
43+
_L_in: int,
44+
_L_out: int,
45+
_torchmetric: str,
46+
*args,
47+
**kwargs,
48+
):
49+
super().__init__()
50+
self._L_in = _L_in
51+
self._L_out = _L_out
52+
53+
# Metric (default accuracy) for logging
54+
# Loss is always BCEWithLogitsLoss or CrossEntropyLoss
55+
if _torchmetric is None:
56+
_torchmetric = "accuracy"
57+
self._torchmetric = _torchmetric.lower()
58+
59+
self._is_binary = self._L_out == 1
60+
61+
self.save_hyperparameters(ignore=["_L_in", "_L_out", "_torchmetric"])
62+
63+
# Dummy-Input für Graph
64+
self.example_input_array = torch.zeros((batch_size, self._L_in))
65+
66+
if self.hparams.l1 < 8:
67+
raise ValueError("l1 must be at least 8")
68+
69+
# Netzwerk wie in deiner Vorlage (Funnel, optional BatchNorm/Dropout)
70+
layers = []
71+
in_features = self._L_in
72+
hidden_size = self.hparams.l1
73+
out_dim = 1 if self._is_binary else self._L_out
74+
75+
for _ in range(self.hparams.num_layers):
76+
out_features = max(hidden_size // 2, 8) # min 8
77+
layers.append(nn.Linear(in_features, hidden_size))
78+
79+
if getattr(self.hparams, "batch_norm", False):
80+
layers.append(nn.BatchNorm1d(hidden_size))
81+
82+
layers.append(self.hparams.act_fn)
83+
layers.append(nn.Dropout(self.hparams.dropout_prob))
84+
85+
in_features = hidden_size
86+
hidden_size = out_features
87+
88+
layers.append(nn.Linear(in_features, out_dim))
89+
self.layers = nn.Sequential(*layers)
90+
91+
# Loss nach Task
92+
if self._is_binary:
93+
# Combined Sigmoid + BCE
94+
self._criterion = nn.BCEWithLogitsLoss()
95+
else:
96+
# Combined Softmax + CE
97+
self._criterion = nn.CrossEntropyLoss()
98+
99+
def forward(self, x: torch.Tensor) -> torch.Tensor:
100+
"""
101+
Returns raw logits. For binary: shape (N,1). For multiclass: (N,C).
102+
"""
103+
return self.layers(x)
104+
105+
# internal helper to compute loss and metric
106+
def _calculate_loss_and_metric(self, batch):
107+
x, y = batch
108+
logits = self(x)
109+
110+
if self._is_binary:
111+
# y -> (N,1) float
112+
y_t = y.view(-1, 1).float()
113+
loss = self._criterion(logits, y_t)
114+
# Für Metriken bereiten wir Schwellen-Preds vor
115+
probs = torch.sigmoid(logits).view(-1)
116+
preds = (probs >= 0.5).long()
117+
target = y.view(-1).long()
118+
else:
119+
# CE expected Long targets (N,) with class indices
120+
loss = self._criterion(logits, y.long())
121+
probs = torch.softmax(logits, dim=1)
122+
preds = torch.argmax(probs, dim=1)
123+
target = y.long()
124+
125+
# metrices
126+
metric_value = None
127+
try:
128+
if self._torchmetric == "accuracy":
129+
if self._is_binary:
130+
# binary accuracy (0/1)
131+
metric_value = TMclf.accuracy(preds, target, task="binary")
132+
else:
133+
metric_value = TMclf.accuracy(preds, target, task="multiclass", num_classes=self._L_out)
134+
else:
135+
# TBC: implement other metrics
136+
pass
137+
except Exception:
138+
metric_value = None
139+
140+
return loss, metric_value
141+
142+
# --- Lightning Hooks ---
143+
def training_step(self, batch: tuple) -> torch.Tensor:
144+
loss, _ = self._calculate_loss_and_metric(batch)
145+
return loss
146+
147+
def validation_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False) -> torch.Tensor:
148+
val_loss, val_metric = self._calculate_loss_and_metric(batch)
149+
self.log("val_loss", val_loss, prog_bar=prog_bar)
150+
self.log("hp_metric", val_loss, prog_bar=prog_bar)
151+
if val_metric is not None:
152+
self.log(f"val_{self._torchmetric}", val_metric, prog_bar=prog_bar)
153+
return val_loss
154+
155+
def test_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False) -> torch.Tensor:
156+
test_loss, test_metric = self._calculate_loss_and_metric(batch)
157+
self.log("test_loss", test_loss, prog_bar=prog_bar)
158+
self.log("hp_metric", test_loss, prog_bar=prog_bar)
159+
if test_metric is not None:
160+
self.log(f"test_{self._torchmetric}", test_metric, prog_bar=prog_bar)
161+
return test_loss
162+
163+
def predict_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False):
164+
x, y = batch
165+
logits = self(x)
166+
if self._is_binary:
167+
probs = torch.sigmoid(logits).view(-1, 1) # (N,1)
168+
preds = (probs >= 0.5).long()
169+
else:
170+
probs = torch.softmax(logits, dim=1) # (N,C)
171+
preds = torch.argmax(probs, dim=1, keepdim=True)
172+
# Debug-Ausgaben wie bei dir:
173+
print(f"Predict step x: {x}")
174+
print(f"Predict step y: {y}")
175+
print(f"Predict step logits: {logits}")
176+
print(f"Predict step probs: {probs}")
177+
print(f"Predict step preds: {preds}")
178+
return (x, y, logits, probs, preds)
179+
180+
def configure_optimizers(self) -> torch.optim.Optimizer:
181+
optimizer = optimizer_handler(optimizer_name=self.hparams.optimizer, params=self.parameters(), lr_mult=self.hparams.lr_mult)
182+
183+
if getattr(self.hparams, "lr_sched", False):
184+
num_milestones = 3
185+
milestones = [int(self.hparams.epochs / (num_milestones + 1) * (i + 1)) for i in range(num_milestones)]
186+
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
187+
lr_scheduler_config = {
188+
"scheduler": scheduler,
189+
"interval": "epoch",
190+
"frequency": 1,
191+
}
192+
return {
193+
"optimizer": optimizer,
194+
"lr_scheduler": lr_scheduler_config,
195+
}
196+
else:
197+
return optimizer

src/spotpython/light/trainmodel.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
693693
X_train_tensor = torch.cat(X_train_list, dim=0).to(model.device)
694694
X_train_tensor.requires_grad_()
695695
X_val_tensor = torch.cat(X_val_list, dim=0).to(model.device)
696-
X_val_tensor.requires_grad_()
696+
X_val_tensor.requires_grad_()
697697

698698
# Dictionary to store attributions
699699
attributions_dict = {}
@@ -709,38 +709,39 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
709709
N_total = X_val_tensor.size(0)
710710
N_attr = min(fun_control["xai_subset_size"], N_total)
711711
print(f"Using a subset of {N_attr} samples for attribution analysis out of {N_total} total samples.")
712-
g = torch.Generator(device=X_val_tensor.device)
713-
g.manual_seed(fun_control["seed"])
714-
perm = torch.randperm(N_total, generator=g,
715-
device=X_val_tensor.device)[:N_attr]
716-
X_val_tensor = X_val_tensor[perm]
712+
g = torch.Generator(device=X_val_tensor.device)
713+
g.manual_seed(fun_control["seed"])
714+
perm = torch.randperm(N_total, generator=g, device=X_val_tensor.device)[:N_attr]
715+
X_val_tensor = X_val_tensor[perm]
717716

718-
# Ensure the model is in evaluation mode
717+
# Ensure the model is in evaluation mode
719718
model.eval()
720719

720+
target = fun_control.get("xai_target", None)
721+
721722
if "KernelShap" in fun_control["xai_methods"]:
722-
attr_ks = KernelShap(model)
723-
n_features = X_val_tensor.shape[1]
724-
samples_ks = min(2000, 100 * n_features) # Adjust number of samples based on features, maximum 2000
725-
print("KernelShap: Using", samples_ks, "samples for attribution.")
726-
with torch.no_grad():
727-
attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline, n_samples=samples_ks, perturbations_per_eval=64)
728-
ks_attr_test_sum = attribution_ks.detach().numpy().sum(axis=0)
729-
l2_norm = np.linalg.norm(ks_attr_test_sum)
730-
l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
731-
attributions_dict["KernelShap"] = l2_normalized_ks
723+
attr_ks = KernelShap(model)
724+
n_features = X_val_tensor.shape[1]
725+
samples_ks = min(2000, 100 * n_features) # Adjust number of samples based on features, maximum 2000
726+
print("KernelShap: Using", samples_ks, "samples for attribution.")
727+
with torch.no_grad():
728+
attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline, n_samples=samples_ks, perturbations_per_eval=64, target=target)
729+
ks_attr_test_sum = attribution_ks.detach().numpy().sum(axis=0)
730+
l2_norm = np.linalg.norm(ks_attr_test_sum)
731+
l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
732+
attributions_dict["KernelShap"] = l2_normalized_ks
732733

733734
with torch.enable_grad():
734735
if "IntegratedGradients" in fun_control["xai_methods"]:
735736
attr_ig = IntegratedGradients(model)
736-
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline)
737+
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline, target=target)
737738
vec = attribution_ig.detach().cpu().numpy().sum(axis=0)
738739
l2 = np.linalg.norm(vec)
739740
attributions_dict["IntegratedGradients"] = vec / l2 if l2 != 0 else vec
740741

741742
if "DeepLift" in fun_control["xai_methods"]:
742743
attr_dl = DeepLift(model)
743-
attribution_dl = attr_dl.attribute(X_val_tensor, baselines=baseline)
744+
attribution_dl = attr_dl.attribute(X_val_tensor, baselines=baseline, target=target)
744745
dl_attr_test_sum = attribution_dl.detach().numpy().sum(axis=0)
745746
l2_norm = np.linalg.norm(dl_attr_test_sum)
746747
l2_normalized_dl = dl_attr_test_sum / l2_norm if l2_norm != 0 else dl_attr_test_sum

0 commit comments

Comments
 (0)