Skip to content

Commit f232c94

Browse files
0.11.0
1 parent e56c4a0 commit f232c94

5 files changed

Lines changed: 20 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.10.69"
10+
version = "0.11.0"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/regression/netlightregression.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class NetLightRegression(L.LightningModule):
1010
"""
11-
A LightningModule class for a regresssion neural network model.
11+
A LightningModule class for a regression neural network model.
1212
1313
Attributes:
1414
l1 (int):
@@ -34,7 +34,8 @@ class NetLightRegression(L.LightningModule):
3434
_L_out (int):
3535
The number of output classes.
3636
_torchmetric (str):
37-
The metric to use for the loss function, e.g., "mean_squared_error".
37+
The metric to use for the loss function. If `None`,
38+
then "mean_squared_error" is used.
3839
layers (nn.Sequential):
3940
The neural network model.
4041
@@ -124,7 +125,9 @@ def __init__(
124125
patience (int): The number of epochs to wait before early stopping.
125126
_L_in (int): The number of input features. Not a hyperparameter, but needed to create the network.
126127
_L_out (int): The number of output classes. Not a hyperparameter, but needed to create the network.
127-
_torchmetric (str): The metric to use for the loss function, e.g., "mean_squared_error".
128+
_torchmetric (str):
129+
The metric to use for the loss function. If `None`,
130+
then "mean_squared_error" is used.
128131
129132
Returns:
130133
(NoneType): None
@@ -141,7 +144,10 @@ def __init__(
141144
#
142145
self._L_in = _L_in
143146
self._L_out = _L_out
147+
if _torchmetric is None:
148+
_torchmetric = "mean_squared_error"
144149
self._torchmetric = _torchmetric
150+
self.metric = getattr(torchmetrics.functional.regression, _torchmetric)
145151
# _L_in and _L_out are not hyperparameters, but are needed to create the network
146152
# _torchmetric is not a hyperparameter, but is needed to calculate the loss
147153
self.save_hyperparameters(ignore=["_L_in", "_L_out", "_torchmetric"])

src/spotPython/light/regression/netlightregression2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def __init__(
141141
#
142142
self._L_in = _L_in
143143
self._L_out = _L_out
144+
if _torchmetric is None:
145+
_torchmetric = "mean_squared_error"
146+
self._torchmetric = _torchmetric
144147
self.metric = getattr(torchmetrics.functional.regression, _torchmetric)
145148
# _L_in and _L_out are not hyperparameters, but are needed to create the network
146149
# _torchmetric is not a hyperparameter, but is needed to calculate the loss

src/spotPython/light/regression/rnnlightregression.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ def __init__(
139139
#
140140
self._L_in = _L_in
141141
self._L_out = _L_out
142+
if _torchmetric is None:
143+
_torchmetric = "mean_squared_error"
142144
self._torchmetric = _torchmetric
145+
self.metric = getattr(torchmetrics.functional.regression, _torchmetric)
143146
# _L_in and _L_out are not hyperparameters, but are needed to create the network
144147
# _torchmetric is not a hyperparameter, but is needed to calculate the loss
145148
self.save_hyperparameters(ignore=["_L_in", "_L_out", "_torchmetric"])

src/spotPython/light/regression/transformerlightregression.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,14 @@ def __init__(
147147
#
148148
self._L_in = _L_in
149149
self._L_out = _L_out
150+
if _torchmetric is None:
151+
_torchmetric = "mean_squared_error"
152+
self._torchmetric = _torchmetric
150153
self.metric = getattr(torchmetrics.functional.regression, _torchmetric)
151-
self.d_mult = d_mult
152154
# _L_in and _L_out are not hyperparameters, but are needed to create the network
153155
# _torchmetric is not a hyperparameter, but is needed to calculate the loss
154156
self.save_hyperparameters(ignore=["_L_in", "_L_out", "_torchmetric"])
157+
self.d_mult = d_mult
155158
# set dummy input array for Tensorboard Graphs
156159
# set log_graph=True in Trainer to see the graph (in traintest.py)
157160
self.example_input_array = torch.zeros((batch_size, self._L_in))

0 commit comments

Comments
 (0)