33import torch .nn .functional as F
44from torch import nn
55from spotPython .hyperparameters .optimizer import optimizer_handler
6+ import torchmetrics .functional .regression
67
78
89class RNNLightRegression (L .LightningModule ):
@@ -32,6 +33,8 @@ class RNNLightRegression(L.LightningModule):
3233 The number of input features.
3334 _L_out (int):
3435 The number of output classes.
36+ _torchmetric (str):
37+ The metric to use for the loss function, e.g., "mean_squared_error".
3538 layers (nn.Sequential):
3639 The neural network model.
3740
@@ -104,6 +107,7 @@ def __init__(
104107 patience : int ,
105108 _L_in : int ,
106109 _L_out : int ,
110+ _torchmetric : str ,
107111 ):
108112 """
109113 Initializes the NetLightRegression object.
@@ -120,6 +124,7 @@ def __init__(
120124 patience (int): The number of epochs to wait before early stopping.
121125 _L_in (int): The number of input features. Not a hyperparameter, but needed to create the network.
122126 _L_out (int): The number of output classes. Not a hyperparameter, but needed to create the network.
127+ _torchmetric (str): The metric to use for the loss function, e.g., "mean_squared_error".
123128
124129 Returns:
125130 (NoneType): None
@@ -133,8 +138,10 @@ def __init__(
133138 #
134139 self ._L_in = _L_in
135140 self ._L_out = _L_out
141+ self ._torchmetric = _torchmetric
136142 # _L_in and _L_out are not hyperparameters, but are needed to create the network
137- self .save_hyperparameters (ignore = ["_L_in" , "_L_out" ])
143+ # _torchmetric is not a hyperparameter, but is needed to calculate the loss
144+ self .save_hyperparameters (ignore = ["_L_in" , "_L_out" , "_torchmetric" ])
138145 # set dummy input array for Tensorboard Graphs
139146 # set log_graph=True in Trainer to see the graph (in traintest.py)
140147 self .example_input_array = torch .zeros ((batch_size , self ._L_in ))
@@ -224,7 +231,9 @@ def training_step(self, batch: tuple, prog_bar: bool = False) -> torch.Tensor:
224231 # Note: the number of rows in x is equal to the number of rows in y
225232 y_hat = self (x )
226233 # Note: the number of rows in y_hat is equal to the number of rows in y
227- train_loss = F .mse_loss (y_hat , y )
234+ # train_loss = F.mse_loss(y_hat, y)
235+ metric = getattr (torchmetrics .functional .regression , self ._torchmetric )
236+ train_loss = metric (y_hat , y )
228237 # mae_loss = F.l1_loss(y_hat, y)
229238 # self.log("train_loss", val_loss, on_step=True, on_epoch=True, prog_bar=True)
230239 # self.log("train_mae_loss", mae_loss, on_step=True, on_epoch=True, prog_bar=True)
0 commit comments