88
99class NetLightRegression (L .LightningModule ):
1010 """
11- A LightningModule class for a regresssion neural network model.
11+ A LightningModule class for a regression neural network model.
1212
1313 Attributes:
1414 l1 (int):
@@ -34,7 +34,8 @@ class NetLightRegression(L.LightningModule):
3434 _L_out (int):
3535 The number of output classes.
3636 _torchmetric (str):
37- The metric to use for the loss function, e.g., "mean_squared_error".
37+ The metric to use for the loss function. If `None`,
38+ then "mean_squared_error" is used.
3839 layers (nn.Sequential):
3940 The neural network model.
4041
@@ -124,7 +125,9 @@ def __init__(
124125 patience (int): The number of epochs to wait before early stopping.
125126 _L_in (int): The number of input features. Not a hyperparameter, but needed to create the network.
126127 _L_out (int): The number of output classes. Not a hyperparameter, but needed to create the network.
127- _torchmetric (str): The metric to use for the loss function, e.g., "mean_squared_error".
128+ _torchmetric (str):
129+ The metric to use for the loss function. If `None`,
130+ then "mean_squared_error" is used.
128131
129132 Returns:
130133 (NoneType): None
@@ -141,7 +144,10 @@ def __init__(
141144 #
142145 self ._L_in = _L_in
143146 self ._L_out = _L_out
147+ if _torchmetric is None :
148+ _torchmetric = "mean_squared_error"
144149 self ._torchmetric = _torchmetric
150+ self .metric = getattr (torchmetrics .functional .regression , _torchmetric )
145151 # _L_in and _L_out are not hyperparameters, but are needed to create the network
146152 # _torchmetric is not a hyperparameter, but is needed to calculate the loss
147153 self .save_hyperparameters (ignore = ["_L_in" , "_L_out" , "_torchmetric" ])
0 commit comments