Skip to content

Commit e1f896b

Browse files
0.15.36
gpu strategies and related parameters
1 parent 9d6a9f6 commit e1f896b

7 files changed

Lines changed: 27 additions & 3 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.35"
10+
version = "0.15.36"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/light/cvmodel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def cv_model(config: dict, fun_control: dict) -> float:
7777
max_epochs=model.hparams.epochs,
7878
accelerator=fun_control["accelerator"],
7979
devices=fun_control["devices"],
80+
strategy=fun_control["strategy"],
81+
num_nodes=fun_control["num_nodes"],
82+
precision=fun_control["precision"],
8083
logger=TensorBoardLogger(
8184
save_dir=fun_control["TENSORBOARD_PATH"],
8285
version=config_id,

src/spotpython/light/predictmodel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
8282
max_epochs=model.hparams.epochs,
8383
accelerator=fun_control["accelerator"],
8484
devices=fun_control["devices"],
85+
strategy=fun_control["strategy"],
86+
num_nodes=fun_control["num_nodes"],
87+
precision=fun_control["precision"],
8588
logger=TensorBoardLogger(
8689
save_dir=fun_control["TENSORBOARD_PATH"],
8790
version=config_id,

src/spotpython/light/testmodel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
8383
max_epochs=model.hparams.epochs,
8484
accelerator=fun_control["accelerator"],
8585
devices=fun_control["devices"],
86+
strategy=fun_control["strategy"],
87+
num_nodes=fun_control["num_nodes"],
88+
precision=fun_control["precision"],
8689
logger=TensorBoardLogger(
8790
save_dir=fun_control["TENSORBOARD_PATH"],
8891
version=config_id,

src/spotpython/light/trainmodel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
119119
max_epochs=model.hparams.epochs,
120120
accelerator=fun_control["accelerator"],
121121
devices=fun_control["devices"],
122+
strategy=fun_control["strategy"],
123+
num_nodes=fun_control["num_nodes"],
124+
precision=fun_control["precision"],
122125
logger=TensorBoardLogger(
123126
save_dir=fun_control["TENSORBOARD_PATH"],
124127
version=config_id,

src/spotpython/utils/init.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def fun_control_init(
3535
db_dict_name=None,
3636
design=None,
3737
device=None,
38-
devices=1,
38+
devices="auto",
3939
enable_progress_bar=False,
4040
EXPERIMENT_NAME=None,
4141
eval=None,
@@ -55,9 +55,11 @@ def fun_control_init(
5555
n_samples=None,
5656
n_total=None,
5757
num_workers=0,
58+
num_nodes=1,
5859
ocba_delta=0,
5960
oml_grace_period=None,
6061
optimizer=None,
62+
precision="32",
6163
prep_model=None,
6264
prep_model_name=None,
6365
progress_file=None,
@@ -70,6 +72,7 @@ def fun_control_init(
7072
show_progress=True,
7173
shuffle=None,
7274
sigma=0.0,
75+
strategy="auto",
7376
surrogate=None,
7477
target_column=None,
7578
target_type=None,
@@ -181,6 +184,8 @@ def fun_control_init(
181184
The number of samples in the dataset. Default is None.
182185
n_total (int):
183186
The total number of samples in the dataset. Default is None.
187+
num_nodes (int):
188+
The number of GPU nodes to use for the training/validation/testing. Default is 1.
184189
num_workers (int):
185190
The number of workers to use for the data loading. Default is 0.
186191
ocba_delta (int):
@@ -190,6 +195,8 @@ def fun_control_init(
190195
The grace period for the OML algorithm. Default is None.
191196
optimizer (object):
192197
The optimizer object used for the search on surrogate. Default is None.
198+
precision (str):
199+
The precision of the data. Default is "32". Can be e.g., "16-mixed" or "16-true".
193200
PREFIX (str):
194201
The prefix of the experiment name. If the PREFIX is not None, a spotWriter
195202
that us an instance of a SummaryWriter(), is created. Default is "00".
@@ -221,6 +228,8 @@ def fun_control_init(
221228
Whether the data were shuffled or not. Default is None.
222229
surrogate (object):
223230
The surrogate model object. Default is None.
231+
strategy (str):
232+
The strategy to use. Default is "auto".
224233
target_column (str):
225234
The name of the target column. Default is None.
226235
target_type (str):
@@ -393,11 +402,13 @@ def fun_control_init(
393402
"n_points": n_points,
394403
"n_samples": n_samples,
395404
"n_total": n_total,
405+
"num_nodes": num_nodes,
396406
"num_workers": num_workers,
397407
"ocba_delta": ocba_delta,
398408
"oml_grace_period": oml_grace_period,
399409
"optimizer": optimizer,
400410
"path": None,
411+
"precision": precision,
401412
"prep_model": prep_model,
402413
"prep_model_name": prep_model_name,
403414
"progress_file": progress_file,
@@ -413,6 +424,7 @@ def fun_control_init(
413424
"shuffle": shuffle,
414425
"sigma": sigma,
415426
"spot_tensorboard_path": spot_tensorboard_path,
427+
"strategy": strategy,
416428
"target_column": target_column,
417429
"target_type": target_type,
418430
"task": task,

test/test_get_activations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_get_activations():
4747
_ = DataLoader(dataset, batch_size=batch_size, shuffle=False)
4848

4949
# Get the activations
50-
activations = get_activations(model, fun_control=fun_control, batch_size=batch_size, device="cpu")
50+
activations, _ = get_activations(model, fun_control=fun_control, batch_size=batch_size, device="cpu")
5151

5252
# Assert that the activations dictionary is not empty
5353
assert len(activations) > 0

0 commit comments

Comments
 (0)