|
| 1 | +import lightning as L |
| 2 | +import torch |
| 3 | +import torch.nn.functional as F |
| 4 | +from torch import nn |
| 5 | +from spotpython.hyperparameters.optimizer import optimizer_handler |
| 6 | +import torchmetrics.functional.classification as TMclf |
| 7 | +import torch.optim as optim |
| 8 | + |
| 9 | + |
| 10 | +class NNFunnelClassifier(L.LightningModule): |
| 11 | + """ |
| 12 | + Funnel-shaped MLP for classification (binary & multiclass). |
| 13 | +
|
| 14 | + Attributes: |
| 15 | + l1 (int): neurons in first hidden layer. |
| 16 | + num_layers (int): number of hidden layers. |
| 17 | + epochs (int): number of training epochs (used for LR scheduler milestones). |
| 18 | + batch_size (int): batch size (used for example_input_array). |
| 19 | + initialization (str): (keine direkte Nutzung hier – identisch zur Vorlage). |
| 20 | + act_fn (nn.Module): activation module (keine Ignorierung; bleibt tunebar). |
| 21 | + optimizer (str): optimizer name for optimizer_handler. |
| 22 | + dropout_prob (float): dropout probability. |
| 23 | + lr_mult (float): learning-rate multiplier (passed to optimizer_handler). |
| 24 | + patience (int): (nicht in dieser Klasse verwendet – wie Vorlage). |
| 25 | + _L_in (int): input dimension. |
| 26 | + _L_out (int): number of classes. If 1 => binary, else multiclass. |
| 27 | + _torchmetric (str): optional metric name ("accuracy" default). Used for logging, not as loss. |
| 28 | + layers (nn.Sequential): the network. |
| 29 | + """ |
| 30 | + |
| 31 | + def __init__( |
| 32 | + self, |
| 33 | + l1: int, |
| 34 | + num_layers: int, |
| 35 | + epochs: int, |
| 36 | + batch_size: int, |
| 37 | + initialization: str, |
| 38 | + act_fn: nn.Module, |
| 39 | + optimizer: str, |
| 40 | + dropout_prob: float, |
| 41 | + lr_mult: float, |
| 42 | + patience: int, |
| 43 | + _L_in: int, |
| 44 | + _L_out: int, |
| 45 | + _torchmetric: str, |
| 46 | + *args, |
| 47 | + **kwargs, |
| 48 | + ): |
| 49 | + super().__init__() |
| 50 | + self._L_in = _L_in |
| 51 | + self._L_out = _L_out |
| 52 | + |
| 53 | + # Metric (default accuracy) for logging |
| 54 | + # Loss is always BCEWithLogitsLoss or CrossEntropyLoss |
| 55 | + if _torchmetric is None: |
| 56 | + _torchmetric = "accuracy" |
| 57 | + self._torchmetric = _torchmetric.lower() |
| 58 | + |
| 59 | + self._is_binary = self._L_out == 1 |
| 60 | + |
| 61 | + self.save_hyperparameters(ignore=["_L_in", "_L_out", "_torchmetric"]) |
| 62 | + |
| 63 | + # Dummy-Input für Graph |
| 64 | + self.example_input_array = torch.zeros((batch_size, self._L_in)) |
| 65 | + |
| 66 | + if self.hparams.l1 < 8: |
| 67 | + raise ValueError("l1 must be at least 8") |
| 68 | + |
| 69 | + # Netzwerk wie in deiner Vorlage (Funnel, optional BatchNorm/Dropout) |
| 70 | + layers = [] |
| 71 | + in_features = self._L_in |
| 72 | + hidden_size = self.hparams.l1 |
| 73 | + out_dim = 1 if self._is_binary else self._L_out |
| 74 | + |
| 75 | + for _ in range(self.hparams.num_layers): |
| 76 | + out_features = max(hidden_size // 2, 8) # min 8 |
| 77 | + layers.append(nn.Linear(in_features, hidden_size)) |
| 78 | + |
| 79 | + if getattr(self.hparams, "batch_norm", False): |
| 80 | + layers.append(nn.BatchNorm1d(hidden_size)) |
| 81 | + |
| 82 | + layers.append(self.hparams.act_fn) |
| 83 | + layers.append(nn.Dropout(self.hparams.dropout_prob)) |
| 84 | + |
| 85 | + in_features = hidden_size |
| 86 | + hidden_size = out_features |
| 87 | + |
| 88 | + layers.append(nn.Linear(in_features, out_dim)) |
| 89 | + self.layers = nn.Sequential(*layers) |
| 90 | + |
| 91 | + # Loss nach Task |
| 92 | + if self._is_binary: |
| 93 | + # Combined Sigmoid + BCE |
| 94 | + self._criterion = nn.BCEWithLogitsLoss() |
| 95 | + else: |
| 96 | + # Combined Softmax + CE |
| 97 | + self._criterion = nn.CrossEntropyLoss() |
| 98 | + |
| 99 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 100 | + """ |
| 101 | + Returns raw logits. For binary: shape (N,1). For multiclass: (N,C). |
| 102 | + """ |
| 103 | + return self.layers(x) |
| 104 | + |
| 105 | + # internal helper to compute loss and metric |
| 106 | + def _calculate_loss_and_metric(self, batch): |
| 107 | + x, y = batch |
| 108 | + logits = self(x) |
| 109 | + |
| 110 | + if self._is_binary: |
| 111 | + # y -> (N,1) float |
| 112 | + y_t = y.view(-1, 1).float() |
| 113 | + loss = self._criterion(logits, y_t) |
| 114 | + # Für Metriken bereiten wir Schwellen-Preds vor |
| 115 | + probs = torch.sigmoid(logits).view(-1) |
| 116 | + preds = (probs >= 0.5).long() |
| 117 | + target = y.view(-1).long() |
| 118 | + else: |
| 119 | + # CE expected Long targets (N,) with class indices |
| 120 | + loss = self._criterion(logits, y.long()) |
| 121 | + probs = torch.softmax(logits, dim=1) |
| 122 | + preds = torch.argmax(probs, dim=1) |
| 123 | + target = y.long() |
| 124 | + |
| 125 | + # metrices |
| 126 | + metric_value = None |
| 127 | + try: |
| 128 | + if self._torchmetric == "accuracy": |
| 129 | + if self._is_binary: |
| 130 | + # binary accuracy (0/1) |
| 131 | + metric_value = TMclf.accuracy(preds, target, task="binary") |
| 132 | + else: |
| 133 | + metric_value = TMclf.accuracy(preds, target, task="multiclass", num_classes=self._L_out) |
| 134 | + else: |
| 135 | + # TBC: implement other metrics |
| 136 | + pass |
| 137 | + except Exception: |
| 138 | + metric_value = None |
| 139 | + |
| 140 | + return loss, metric_value |
| 141 | + |
| 142 | + # --- Lightning Hooks --- |
| 143 | + def training_step(self, batch: tuple) -> torch.Tensor: |
| 144 | + loss, _ = self._calculate_loss_and_metric(batch) |
| 145 | + return loss |
| 146 | + |
| 147 | + def validation_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False) -> torch.Tensor: |
| 148 | + val_loss, val_metric = self._calculate_loss_and_metric(batch) |
| 149 | + self.log("val_loss", val_loss, prog_bar=prog_bar) |
| 150 | + self.log("hp_metric", val_loss, prog_bar=prog_bar) |
| 151 | + if val_metric is not None: |
| 152 | + self.log(f"val_{self._torchmetric}", val_metric, prog_bar=prog_bar) |
| 153 | + return val_loss |
| 154 | + |
| 155 | + def test_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False) -> torch.Tensor: |
| 156 | + test_loss, test_metric = self._calculate_loss_and_metric(batch) |
| 157 | + self.log("test_loss", test_loss, prog_bar=prog_bar) |
| 158 | + self.log("hp_metric", test_loss, prog_bar=prog_bar) |
| 159 | + if test_metric is not None: |
| 160 | + self.log(f"test_{self._torchmetric}", test_metric, prog_bar=prog_bar) |
| 161 | + return test_loss |
| 162 | + |
| 163 | + def predict_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False): |
| 164 | + x, y = batch |
| 165 | + logits = self(x) |
| 166 | + if self._is_binary: |
| 167 | + probs = torch.sigmoid(logits).view(-1, 1) # (N,1) |
| 168 | + preds = (probs >= 0.5).long() |
| 169 | + else: |
| 170 | + probs = torch.softmax(logits, dim=1) # (N,C) |
| 171 | + preds = torch.argmax(probs, dim=1, keepdim=True) |
| 172 | + # Debug-Ausgaben wie bei dir: |
| 173 | + print(f"Predict step x: {x}") |
| 174 | + print(f"Predict step y: {y}") |
| 175 | + print(f"Predict step logits: {logits}") |
| 176 | + print(f"Predict step probs: {probs}") |
| 177 | + print(f"Predict step preds: {preds}") |
| 178 | + return (x, y, logits, probs, preds) |
| 179 | + |
| 180 | + def configure_optimizers(self) -> torch.optim.Optimizer: |
| 181 | + optimizer = optimizer_handler(optimizer_name=self.hparams.optimizer, params=self.parameters(), lr_mult=self.hparams.lr_mult) |
| 182 | + |
| 183 | + if getattr(self.hparams, "lr_sched", False): |
| 184 | + num_milestones = 3 |
| 185 | + milestones = [int(self.hparams.epochs / (num_milestones + 1) * (i + 1)) for i in range(num_milestones)] |
| 186 | + scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1) |
| 187 | + lr_scheduler_config = { |
| 188 | + "scheduler": scheduler, |
| 189 | + "interval": "epoch", |
| 190 | + "frequency": 1, |
| 191 | + } |
| 192 | + return { |
| 193 | + "optimizer": optimizer, |
| 194 | + "lr_scheduler": lr_scheduler_config, |
| 195 | + } |
| 196 | + else: |
| 197 | + return optimizer |
0 commit comments