Skip to content

Commit 622220c

Browse files
0.15.11
surrogate_control_init accepts n_theta="anisotropic" which is the new default print output reduced
1 parent c7ba86d commit 622220c

6 files changed

Lines changed: 44 additions & 24 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.10"
10+
version = "0.15.11"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/data/lightdatamodule.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
test_seed: int = 42,
8585
num_workers: int = 0,
8686
scaler: Optional[object] = None,
87+
verbosity: int = 0,
8788
):
8889
super().__init__()
8990
self.batch_size = batch_size
@@ -92,6 +93,7 @@ def __init__(
9293
self.test_seed = test_seed
9394
self.num_workers = num_workers
9495
self.scaler = scaler
96+
self.verbosity = verbosity
9597

9698
def prepare_data(self) -> None:
9799
"""Prepares the data for use."""
@@ -134,15 +136,18 @@ def setup(self, stage: Optional[str] = None) -> None:
134136
val_size = int(full_train_size * test_size / len(self.data_full))
135137
train_size = full_train_size - val_size
136138

137-
print(f"LightDataModule.setup(): stage: {stage}")
138-
# print(f"LightDataModule setup(): full_train_size: {full_train_size}")
139-
# print(f"LightDataModule setup(): val_size: {val_size}")
140-
# print(f"LightDataModule setup(): train_size: {train_size}")
141-
# print(f"LightDataModule setup(): test_size: {test_size}")
139+
if self.verbosity > 0:
140+
print(f"LightDataModule.setup(): stage: {stage}")
141+
if self.verbosity > 1:
142+
print(f"LightDataModule setup(): full_train_size: {full_train_size}")
143+
print(f"LightDataModule setup(): val_size: {val_size}")
144+
print(f"LightDataModule setup(): train_size: {train_size}")
145+
print(f"LightDataModule setup(): test_size: {test_size}")
142146

143147
# Assign train/val datasets for use in dataloaders
144148
if stage == "fit" or stage is None:
145-
print(f"train_size: {train_size}, val_size: {val_size} used for train & val data.")
149+
if self.verbosity > 0:
150+
print(f"train_size: {train_size}, val_size: {val_size} used for train & val data.")
146151
generator_fit = torch.Generator().manual_seed(self.test_seed)
147152
self.data_train, self.data_val, _ = random_split(
148153
self.data_full, [train_size, val_size, test_size], generator=generator_fit
@@ -151,7 +156,8 @@ def setup(self, stage: Optional[str] = None) -> None:
151156
# Fit the scaler on training data and transform both train and val data
152157
scaler_train_data = torch.stack([self.data_train[i][0] for i in range(len(self.data_train))]).squeeze(1)
153158
# train_val_data = self.data_train[:,0]
154-
print(scaler_train_data.shape)
159+
if self.verbosity > 0:
160+
print(scaler_train_data.shape)
155161
self.scaler.fit(scaler_train_data)
156162
self.data_train = [(self.scaler.transform(data), target) for data, target in self.data_train]
157163
data_tensors_train = [data.clone().detach() for data, target in self.data_train]
@@ -167,7 +173,8 @@ def setup(self, stage: Optional[str] = None) -> None:
167173

168174
# Assign test dataset for use in dataloader(s)
169175
if stage == "test" or stage is None:
170-
print(f"test_size: {test_size} used for test dataset.")
176+
if self.verbosity > 0:
177+
print(f"test_size: {test_size} used for test dataset.")
171178
# get test data set as test_abs percent of the full dataset
172179
generator_test = torch.Generator().manual_seed(self.test_seed)
173180
self.data_test, _ = random_split(self.data_full, [test_size, full_train_size], generator=generator_test)
@@ -190,7 +197,8 @@ def setup(self, stage: Optional[str] = None) -> None:
190197

191198
# Assign pred dataset for use in dataloader(s)
192199
if stage == "predict" or stage is None:
193-
print(f"test_size: {test_size} used for predict dataset.")
200+
if self.verbosity > 0:
201+
print(f"test_size: {test_size} used for predict dataset.")
194202
# get test data set as test_abs percent of the full dataset
195203
generator_predict = torch.Generator().manual_seed(self.test_seed)
196204
self.data_predict, _ = random_split(
@@ -223,7 +231,8 @@ def train_dataloader(self) -> DataLoader:
223231
Training set size: 3
224232
225233
"""
226-
print(f"LightDataModule.train_dataloader(). data_train size: {len(self.data_train)}")
234+
if self.verbosity > 0:
235+
print(f"LightDataModule.train_dataloader(). data_train size: {len(self.data_train)}")
227236
# print(f"LightDataModule: train_dataloader(). batch_size: {self.batch_size}")
228237
# print(f"LightDataModule: train_dataloader(). num_workers: {self.num_workers}")
229238
# apply fit_transform to the training data
@@ -247,7 +256,8 @@ def val_dataloader(self) -> DataLoader:
247256
print(f"Training set size: {len(data_module.data_val)}")
248257
Training set size: 3
249258
"""
250-
print(f"LightDataModule.val_dataloader(). Val. set size: {len(self.data_val)}")
259+
if self.verbosity > 0:
260+
print(f"LightDataModule.val_dataloader(). Val. set size: {len(self.data_val)}")
251261
# print(f"LightDataModule: val_dataloader(). batch_size: {self.batch_size}")
252262
# print(f"LightDataModule: val_dataloader(). num_workers: {self.num_workers}")
253263
# apply fit_transform to the val data
@@ -272,7 +282,8 @@ def test_dataloader(self) -> DataLoader:
272282
Test set size: 6
273283
274284
"""
275-
print(f"LightDataModule.test_dataloader(). Test set size: {len(self.data_test)}")
285+
if self.verbosity > 0:
286+
print(f"LightDataModule.test_dataloader(). Test set size: {len(self.data_test)}")
276287
# print(f"LightDataModule: test_dataloader(). batch_size: {self.batch_size}")
277288
# print(f"LightDataModule: test_dataloader(). num_workers: {self.num_workers}")
278289
# apply fit_transform to the val data
@@ -297,7 +308,8 @@ def predict_dataloader(self) -> DataLoader:
297308
Predict set size: 6
298309
299310
"""
300-
print(f"LightDataModule.predict_dataloader(). Predict set size: {len(self.data_predict)}")
311+
if self.verbosity > 0:
312+
print(f"LightDataModule.predict_dataloader(). Predict set size: {len(self.data_predict)}")
301313
# print(f"LightDataModule: predict_dataloader(). batch_size: {self.batch_size}")
302314
# print(f"LightDataModule: predict_dataloader(). num_workers: {self.num_workers}")
303315
# apply fit_transform to the val data

src/spotpython/light/cvmodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def cv_model(config: dict, fun_control: dict) -> float:
101101
trainer.fit(model=model, datamodule=dm)
102102
# Test best model on validation and test set
103103
# result = trainer.validate(model=model, datamodule=dm, ckpt_path="last")
104-
score = trainer.validate(model=model, datamodule=dm)
104+
verbose = fun_control["verbosity"] > 0
105+
score = trainer.validate(model=model, datamodule=dm, verbose=verbose)
105106
# unlist the result (from a list of one dict)
106107
score = score[0]
107108
print(f"train_model result: {score}")

src/spotpython/light/trainmodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
140140
trainer.fit(model=model, datamodule=dm)
141141
# Test best model on validation and test set
142142
# result = trainer.validate(model=model, datamodule=dm, ckpt_path="last")
143-
result = trainer.validate(model=model, datamodule=dm)
143+
verbose = fun_control["verbosity"] > 0
144+
result = trainer.validate(model=model, datamodule=dm, verbose=verbose)
144145
# unlist the result (from a list of one dict)
145146
result = result[0]
146147
print(f"train_model result: {result}")

src/spotpython/spot/spot.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,16 @@ def __init__(
306306
if self.surrogate_control["model_optimizer"] is None or optimizer is not None:
307307
self.surrogate_control.update({"model_optimizer": self.optimizer})
308308

309-
# If self.surrogate_control["n_theta"] > 1, use k theta values:
310-
if self.surrogate_control["n_theta"] > 1:
311-
surrogate_control.update({"n_theta": self.k})
312-
else:
313-
surrogate_control.update({"n_theta": 1})
309+
# if self.surrogate_control["n_theta"] is a string and == isotropic, use 1 theta value:
310+
if isinstance(self.surrogate_control["n_theta"], str):
311+
if self.surrogate_control["n_theta"] == "anisotropic":
312+
surrogate_control.update({"n_theta": self.k})
313+
else:
314+
# case "isotropic":
315+
surrogate_control.update({"n_theta": 1})
316+
if isinstance(self.surrogate_control["n_theta"], int):
317+
if self.surrogate_control["n_theta"] > 1:
318+
surrogate_control.update({"n_theta": self.k})
314319

315320
# If no surrogate model is specified, use the internal
316321
# spotpython kriging surrogate:

src/spotpython/utils/init.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def surrogate_control_init(
599599
model_fun_evals=10000,
600600
min_theta=-3.0,
601601
max_theta=2.0,
602-
n_theta=1,
602+
n_theta="anisotropic",
603603
p_val=2.0,
604604
n_p=1,
605605
optim_p=False,
@@ -628,8 +628,9 @@ def surrogate_control_init(
628628
Whether the objective function is noisy or not. If Kriging, then a nugget is added.
629629
Default is False. Note: Will be set in the Spot class.
630630
n_theta (int):
631-
The number of theta values. If larger than 1, then the k theta values are
632-
used, where k is the problem dimension. Default is 1.
631+
The number of theta values. If larger than 1 or set to the string "anisotropic",
632+
then the k theta values are used, where k is the problem dimension.
633+
This is handled in spot.py. Default is "anisotropic".
633634
p_val (float):
634635
p value. Used as an initial value if optim_p = True. Otherwise as a constant. Defaults to 2.0.
635636
n_p (int):

0 commit comments

Comments
 (0)