Skip to content

Commit 14fb2a3

Browse files
0.10.68
torch lightning regression models require a new argument: "_torchmetric"
1 parent 563a01d commit 14fb2a3

15 files changed

Lines changed: 68 additions & 21 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.10.67"
10+
version = "0.10.68"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/cvmodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def cv_model(config: dict, fun_control: dict) -> float:
3939
"""
4040
_L_in = fun_control["_L_in"]
4141
_L_out = fun_control["_L_out"]
42+
_torchmetric = fun_control["_torchmetric"]
4243
if fun_control["enable_progress_bar"] is None:
4344
enable_progress_bar = False
4445
else:
@@ -52,7 +53,7 @@ def cv_model(config: dict, fun_control: dict) -> float:
5253
for k in range(num_folds):
5354
print("k:", k)
5455

55-
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
56+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
5657
initialization = config["initialization"]
5758
if initialization == "Xavier":
5859
xavier_init(model)

src/spotPython/light/predictmodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
5454
"""
5555
_L_in = fun_control["_L_in"]
5656
_L_out = fun_control["_L_out"]
57+
_torchmetric = fun_control["_torchmetric"]
5758
if fun_control["enable_progress_bar"] is None:
5859
enable_progress_bar = False
5960
else:
@@ -72,7 +73,7 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7273
)
7374
dm.setup(stage="train")
7475
# Init model from datamodule's attributes
75-
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
76+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
7677
initialization = config["initialization"]
7778
if initialization == "Xavier":
7879
xavier_init(model)

src/spotPython/light/regression/netlightregression.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch.nn.functional as F
44
from torch import nn
55
from spotPython.hyperparameters.optimizer import optimizer_handler
6+
import torchmetrics.functional.regression
67

78

89
class NetLightRegression(L.LightningModule):
@@ -32,6 +33,8 @@ class NetLightRegression(L.LightningModule):
3233
The number of input features.
3334
_L_out (int):
3435
The number of output classes.
36+
_torchmetric (str):
37+
The metric to use for the loss function, e.g., "mean_squared_error".
3538
layers (nn.Sequential):
3639
The neural network model.
3740
@@ -104,6 +107,7 @@ def __init__(
104107
patience: int,
105108
_L_in: int,
106109
_L_out: int,
110+
_torchmetric: str,
107111
):
108112
"""
109113
Initializes the NetLightRegression object.
@@ -120,6 +124,7 @@ def __init__(
120124
patience (int): The number of epochs to wait before early stopping.
121125
_L_in (int): The number of input features. Not a hyperparameter, but needed to create the network.
122126
_L_out (int): The number of output classes. Not a hyperparameter, but needed to create the network.
127+
_torchmetric (str): The metric to use for the loss function, e.g., "mean_squared_error".
123128
124129
Returns:
125130
(NoneType): None
@@ -136,8 +141,10 @@ def __init__(
136141
#
137142
self._L_in = _L_in
138143
self._L_out = _L_out
144+
self._torchmetric = _torchmetric
139145
# _L_in and _L_out are not hyperparameters, but are needed to create the network
140-
self.save_hyperparameters(ignore=["_L_in", "_L_out"])
146+
# _torchmetric is not a hyperparameter, but is needed to calculate the loss
147+
self.save_hyperparameters(ignore=["_L_in", "_L_out", "_torchmetric"])
141148
# set dummy input array for Tensorboard Graphs
142149
# set log_graph=True in Trainer to see the graph (in traintest.py)
143150
self.example_input_array = torch.zeros((batch_size, self._L_in))
@@ -189,7 +196,9 @@ def training_step(self, batch: tuple) -> torch.Tensor:
189196
x, y = batch
190197
y = y.view(len(y), 1)
191198
y_hat = self(x)
192-
val_loss = F.mse_loss(y_hat, y)
199+
# val_loss = F.mse_loss(y_hat, y)
200+
metric = getattr(torchmetrics.functional.regression, self._torchmetric)
201+
val_loss = metric(y_hat, y)
193202
# mae_loss = F.l1_loss(y_hat, y)
194203
# self.log("train_loss", val_loss, on_step=True, on_epoch=True, prog_bar=True)
195204
# self.log("train_mae_loss", mae_loss, on_step=True, on_epoch=True, prog_bar=True)

src/spotPython/light/regression/netlightregression2.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import lightning as L
22
import torch
3-
import torch.nn.functional as F
43
from torch import nn
54
from spotPython.hyperparameters.optimizer import optimizer_handler
65
from spotPython.utils.math import generate_div2_list
6+
import torchmetrics.functional.regression
77

88

99
class NetLightRegression2(L.LightningModule):
@@ -33,6 +33,8 @@ class NetLightRegression2(L.LightningModule):
3333
The number of input features.
3434
_L_out (int):
3535
The number of output classes.
36+
_torchmetric (str):
37+
The metric to use for the loss function, e.g., "mean_squared_error".
3638
layers (nn.Sequential):
3739
The neural network model.
3840
@@ -105,6 +107,7 @@ def __init__(
105107
patience: int,
106108
_L_in: int,
107109
_L_out: int,
110+
_torchmetric: str,
108111
):
109112
"""
110113
Initializes the NetLightRegression2 object.
@@ -121,6 +124,7 @@ def __init__(
121124
patience (int): The number of epochs to wait before early stopping.
122125
_L_in (int): The number of input features. Not a hyperparameter, but needed to create the network.
123126
_L_out (int): The number of output classes. Not a hyperparameter, but needed to create the network.
127+
_torchmetric (str): The metric to use for the loss function, e.g., "mean_squared_error".
124128
125129
Returns:
126130
(NoneType): None
@@ -137,8 +141,10 @@ def __init__(
137141
#
138142
self._L_in = _L_in
139143
self._L_out = _L_out
144+
self.metric = getattr(torchmetrics.functional.regression, _torchmetric)
140145
# _L_in and _L_out are not hyperparameters, but are needed to create the network
141-
self.save_hyperparameters(ignore=["_L_in", "_L_out"])
146+
# _torchmetric is not a hyperparameter, but is needed to calculate the loss
147+
self.save_hyperparameters(ignore=["_L_in", "_L_out", "_torchmetric"])
142148
# set dummy input array for Tensorboard Graphs
143149
# set log_graph=True in Trainer to see the graph (in traintest.py)
144150
self.example_input_array = torch.zeros((batch_size, self._L_in))
@@ -196,7 +202,8 @@ def _calculate_loss(self, batch, mode="train"):
196202
x, y = batch
197203
y = y.view(len(y), 1)
198204
y_hat = self(x)
199-
loss = F.mse_loss(y_hat, y)
205+
# loss = F.mse_loss(y_hat, y)
206+
loss = self.metric(y_hat, y)
200207
if mode == "val" or mode == "test":
201208
self.log(f"{mode}_loss", loss, prog_bar=True)
202209
self.log("hp_metric", loss, prog_bar=True)

src/spotPython/light/regression/rnnlightregression.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch.nn.functional as F
44
from torch import nn
55
from spotPython.hyperparameters.optimizer import optimizer_handler
6+
import torchmetrics.functional.regression
67

78

89
class RNNLightRegression(L.LightningModule):
@@ -32,6 +33,8 @@ class RNNLightRegression(L.LightningModule):
3233
The number of input features.
3334
_L_out (int):
3435
The number of output classes.
36+
_torchmetric (str):
37+
The metric to use for the loss function, e.g., "mean_squared_error".
3538
layers (nn.Sequential):
3639
The neural network model.
3740
@@ -104,6 +107,7 @@ def __init__(
104107
patience: int,
105108
_L_in: int,
106109
_L_out: int,
110+
_torchmetric: str,
107111
):
108112
"""
109113
Initializes the RNNLightRegression object.
@@ -120,6 +124,7 @@ def __init__(
120124
patience (int): The number of epochs to wait before early stopping.
121125
_L_in (int): The number of input features. Not a hyperparameter, but needed to create the network.
122126
_L_out (int): The number of output classes. Not a hyperparameter, but needed to create the network.
127+
_torchmetric (str): The metric to use for the loss function, e.g., "mean_squared_error".
123128
124129
Returns:
125130
(NoneType): None
@@ -133,8 +138,10 @@ def __init__(
133138
#
134139
self._L_in = _L_in
135140
self._L_out = _L_out
141+
self._torchmetric = _torchmetric
136142
# _L_in and _L_out are not hyperparameters, but are needed to create the network
137-
self.save_hyperparameters(ignore=["_L_in", "_L_out"])
143+
# _torchmetric is not a hyperparameter, but is needed to calculate the loss
144+
self.save_hyperparameters(ignore=["_L_in", "_L_out", "_torchmetric"])
138145
# set dummy input array for Tensorboard Graphs
139146
# set log_graph=True in Trainer to see the graph (in traintest.py)
140147
self.example_input_array = torch.zeros((batch_size, self._L_in))
@@ -224,7 +231,9 @@ def training_step(self, batch: tuple, prog_bar: bool = False) -> torch.Tensor:
224231
# Note: the number of rows in x is equal to the number of rows in y
225232
y_hat = self(x)
226233
# Note: the number of rows in y_hat is equal to the number of rows in y
227-
train_loss = F.mse_loss(y_hat, y)
234+
# train_loss = F.mse_loss(y_hat, y)
235+
metric = getattr(torchmetrics.functional.regression, self._torchmetric)
236+
train_loss = metric(y_hat, y)
228237
# mae_loss = F.l1_loss(y_hat, y)
229238
# self.log("train_loss", val_loss, on_step=True, on_epoch=True, prog_bar=True)
230239
# self.log("train_mae_loss", mae_loss, on_step=True, on_epoch=True, prog_bar=True)

src/spotPython/light/regression/transformerlightregression.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from spotPython.light.transformer.skiplinear import SkipLinear
77
from spotPython.light.transformer.positionalEncoding import PositionalEncoding
88
from spotPython.utils.math import generate_div2_list
9+
import torchmetrics.functional.regression
910

1011

1112
class TransformerLightRegression(L.LightningModule):
@@ -35,6 +36,8 @@ class TransformerLightRegression(L.LightningModule):
3536
The number of input features.
3637
_L_out (int):
3738
The number of output classes.
39+
_torchmetric (str):
40+
The metric to use for the loss function, e.g., "mean_squared_error".
3841
layers (nn.Sequential):
3942
The neural network model.
4043
@@ -110,6 +113,7 @@ def __init__(
110113
patience: int,
111114
_L_in: int,
112115
_L_out: int,
116+
_torchmetric: str,
113117
):
114118
"""
115119
Initializes the TransformerLightRegression object.
@@ -126,6 +130,7 @@ def __init__(
126130
patience (int): The number of epochs to wait before early stopping.
127131
_L_in (int): The number of input features. Not a hyperparameter, but needed to create the network.
128132
_L_out (int): The number of output classes. Not a hyperparameter, but needed to create the network.
133+
_torchmetric (str): The metric to use for the loss function, e.g., "mean_squared_error".
129134
130135
Returns:
131136
(NoneType): None
@@ -142,9 +147,11 @@ def __init__(
142147
#
143148
self._L_in = _L_in
144149
self._L_out = _L_out
150+
self.metric = getattr(torchmetrics.functional.regression, _torchmetric)
145151
self.d_mult = d_mult
146152
# _L_in and _L_out are not hyperparameters, but are needed to create the network
147-
self.save_hyperparameters(ignore=["_L_in", "_L_out"])
153+
# _torchmetric is not a hyperparameter, but is needed to calculate the loss
154+
self.save_hyperparameters(ignore=["_L_in", "_L_out", "_torchmetric"])
148155
# set dummy input array for Tensorboard Graphs
149156
# set log_graph=True in Trainer to see the graph (in traintest.py)
150157
self.example_input_array = torch.zeros((batch_size, self._L_in))

src/spotPython/light/testmodel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
3939
import spotPython.light.testmodel as tm
4040
fun_control = fun_control_init(
4141
_L_in=10,
42-
_L_out=1,)
42+
_L_out=1,
43+
_torchmetric="mean_squared_error")
4344
dataset = Diabetes()
4445
set_control_key_value(control_dict=fun_control,
4546
key="data_set",
@@ -54,6 +55,7 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
5455
"""
5556
_L_in = fun_control["_L_in"]
5657
_L_out = fun_control["_L_out"]
58+
_torchmetric = fun_control["_torchmetric"]
5759
if fun_control["enable_progress_bar"] is None:
5860
enable_progress_bar = False
5961
else:
@@ -72,7 +74,7 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7274
)
7375
dm.setup()
7476
# Init model from datamodule's attributes
75-
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
77+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
7678
initialization = config["initialization"]
7779
if initialization == "Xavier":
7880
xavier_init(model)

src/spotPython/light/trainmodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,15 @@ def train_model(config: dict, fun_control: dict) -> float:
6767
"""
6868
_L_in = fun_control["_L_in"]
6969
_L_out = fun_control["_L_out"]
70+
_torchmetric = fun_control["_torchmetric"]
7071
if fun_control["enable_progress_bar"] is None:
7172
enable_progress_bar = False
7273
else:
7374
enable_progress_bar = fun_control["enable_progress_bar"]
7475
# config id is unique. Since the model is not loaded from a checkpoint,
7576
# the config id is generated here with a timestamp.
7677
config_id = generate_config_id(config, timestamp=True)
77-
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
78+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
7879
initialization = config["initialization"]
7980
if initialization == "Xavier":
8081
xavier_init(model)

src/spotPython/utils/init.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
def fun_control_init(
1212
_L_in=None,
1313
_L_out=None,
14+
_torchmetric=None,
1415
PREFIX=None,
1516
TENSORBOARD_CLEAN=False,
1617
SUMMARY_WRITER=True,
@@ -69,6 +70,10 @@ def fun_control_init(
6970
The number of input features.
7071
_L_out (int):
7172
The number of output features.
73+
_torchmetric (str):
74+
The metric to be used by the Lightning Trainer.
75+
For example "mean_squared_error",
76+
see https://lightning.ai/docs/torchmetrics/stable/regression/mean_squared_error.html
7277
accelerator (str):
7378
The accelerator to be used by the Lightning Trainer.
7479
It can be either "auto", "dp", "ddp", "ddp2", "ddp_spawn", "ddp_cpu", "gpu", "tpu".
@@ -294,6 +299,7 @@ def fun_control_init(
294299
"TENSORBOARD_PATH": TENSORBOARD_PATH,
295300
"_L_in": _L_in,
296301
"_L_out": _L_out,
302+
"_torchmetric": _torchmetric,
297303
"accelerator": accelerator,
298304
"converters": converters,
299305
"core_model": core_model,

0 commit comments

Comments
 (0)