Skip to content

Commit 3576a17

Browse files
0.10.26
prediction in DataModule fixed
1 parent 9c535fa commit 3576a17

4 files changed

Lines changed: 45 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.10.25"
10+
version = "0.10.26"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/lightdatamodule.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,22 @@ def setup(self, stage: Optional[str] = None) -> None:
124124
generator_test = torch.Generator().manual_seed(self.test_seed)
125125
self.data_test, _ = random_split(self.data_full, [test_size, full_train_size], generator=generator_test)
126126

127+
# if stage == "predict" or stage is None:
128+
# print(f"test_size, full_train_size: {test_size}, {full_train_size}")
129+
# generator_predict = torch.Generator().manual_seed(self.test_seed)
130+
# full_data_predict, _ = random_split(
131+
# self.data_full, [test_size, full_train_size], generator=generator_predict
132+
# )
133+
# # Only keep the features for prediction
134+
# self.data_predict = [x for x, _ in full_data_predict]
135+
136+
# Assign pred dataset for use in dataloader(s)
127137
if stage == "predict" or stage is None:
128-
print(f"test_size, full_train_size: {test_size}, {full_train_size}")
138+
# get test data aset as test_abs percent of the full dataset
129139
generator_predict = torch.Generator().manual_seed(self.test_seed)
130-
full_data_predict, _ = random_split(
140+
self.data_predict, _ = random_split(
131141
self.data_full, [test_size, full_train_size], generator=generator_predict
132142
)
133-
# Only keep the features for prediction
134-
self.data_predict = [x for x, _ in full_data_predict]
135143

136144
def train_dataloader(self) -> DataLoader:
137145
"""
@@ -209,4 +217,4 @@ def predict_dataloader(self) -> DataLoader:
209217
print(f"LightDataModule: predict_dataloader(). Predict set size: {len(self.data_predict)}")
210218
print(f"LightDataModule: predict_dataloader(). batch_size: {self.batch_size}")
211219
print(f"LightDataModule: predict_dataloader(). num_workers: {self.num_workers}")
212-
return DataLoader(self.data_predict, batch_size=self.batch_size, num_workers=self.num_workers)
220+
return DataLoader(self.data_predict, batch_size=len(self.data_predict), num_workers=self.num_workers)

src/spotPython/light/predictmodel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
103103
trainer.fit(model=model, datamodule=dm)
104104

105105
dm.setup(stage="predict")
106-
# predictions = trainer.predict(model=model, datamodule=dm)
107-
predictions = trainer.predict(datamodule=dm)
106+
predictions = trainer.predict(model=model, datamodule=dm)
107+
# predictions = trainer.predict(datamodule=dm)
108108

109109
# # Load the last checkpoint
110110
# test_result = trainer.test(datamodule=dm, ckpt_path="last")

src/spotPython/light/regression/netlightregression.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,35 @@ def test_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False) -> tor
239239
self.log("hp_metric", val_loss, prog_bar=prog_bar)
240240
return val_loss
241241

242+
def predict_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False) -> torch.Tensor:
243+
"""
244+
Performs a single prediction step.
245+
246+
Args:
247+
batch (tuple): A tuple containing a batch of input data and labels.
248+
batch_idx (int): The index of the current batch.
249+
prog_bar (bool, optional): Whether to display the progress bar. Defaults to False.
250+
251+
Returns:
252+
torch.Tensor: A tensor containing the prediction for this batch.
253+
"""
254+
x, y = batch
255+
yhat = self(x)
256+
y = y.view(len(y), 1)
257+
yhat = yhat.view(len(yhat), 1)
258+
print(f"Predict step x: {x}")
259+
print(f"Predict step y: {y}")
260+
print(f"Predict step y_hat: {yhat}")
261+
# pred_loss = F.mse_loss(y_hat, y)
262+
# pred loss not registered
263+
# self.log("pred_loss", pred_loss, prog_bar=prog_bar)
264+
# self.log("hp_metric", pred_loss, prog_bar=prog_bar)
265+
# MisconfigurationException: You are trying to `self.log()`
266+
# but the loop's result collection is not registered yet.
267+
# This is most likely because you are trying to log in a `predict` hook, but it doesn't support logging.
268+
# If you want to manually log, please consider using `self.log_dict({'pred_loss': pred_loss})` instead.
269+
return (x, y, yhat)
270+
242271
def configure_optimizers(self) -> torch.optim.Optimizer:
243272
"""
244273
Configures the optimizer for the model.

0 commit comments

Comments
 (0)